Pull out indices creation for kv cache update
This commit is contained in:
parent
bf5414c95e
commit
acf358d1ce
|
|
@ -71,6 +71,9 @@ public:
|
|||
}
|
||||
|
||||
Output<Node> get_input(const std::string& name) const override {
|
||||
if (m_tensor_map->find(name) == m_tensor_map->end()) {
|
||||
throw std::runtime_error("'" + name + "' not found in tensor map.");
|
||||
}
|
||||
return m_tensor_map->at(name);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,19 +4,11 @@
|
|||
#include <openvino/core/node.hpp>
|
||||
#include <openvino/core/node_output.hpp>
|
||||
#include <openvino/core/node_vector.hpp>
|
||||
#include <openvino/op/add.hpp>
|
||||
#include <openvino/op/broadcast.hpp>
|
||||
#include <openvino/op/concat.hpp>
|
||||
#include <openvino/op/constant.hpp>
|
||||
#include <openvino/op/convert.hpp>
|
||||
#include <openvino/op/range.hpp>
|
||||
#include <openvino/op/reshape.hpp>
|
||||
#include <openvino/op/scatter_nd_update.hpp>
|
||||
#include <openvino/op/shape_of.hpp>
|
||||
#include <openvino/op/slice.hpp>
|
||||
#include <openvino/op/squeeze.hpp>
|
||||
#include <openvino/op/transpose.hpp>
|
||||
#include <openvino/op/unsqueeze.hpp>
|
||||
#include <vector>
|
||||
|
||||
#include "../node_context.hpp"
|
||||
|
|
@ -36,8 +28,13 @@ OutputVector translate_cpy(const NodeContext& context) {
|
|||
|
||||
auto src0 = context.get_input(0);
|
||||
auto src1 = context.get_input(1);
|
||||
auto token_len = context.get_input("token_len");
|
||||
auto past_token_len = context.get_input("past_token_len");
|
||||
|
||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(token_len, zero);
|
||||
auto past_token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(past_token_len, zero);
|
||||
|
||||
src0 = std::make_shared<ov::op::v0::Convert>(src0, context.get_input_type(1));
|
||||
ov::Output<Node> res;
|
||||
|
||||
|
|
@ -52,89 +49,24 @@ OutputVector translate_cpy(const NodeContext& context) {
|
|||
std::vector<size_t> input0_strides = context.get_input_stride(0);
|
||||
std::vector<size_t> output_strides = context.get_output_stride(0);
|
||||
|
||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
auto one_scalar = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1});
|
||||
|
||||
if (op_case == 1) {
|
||||
// Write K to cache_k
|
||||
auto token_len = get_dimensions(src0.get_node_shared_ptr(), {0});
|
||||
auto token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(token_len, zero);
|
||||
|
||||
std::shared_ptr<ov::Node> indices;
|
||||
if (context.is_static()) {
|
||||
indices = past_token_len.get_node_shared_ptr();
|
||||
} else {
|
||||
auto past_token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(past_token_len, zero);
|
||||
auto total_token_len_scalar = std::make_shared<ov::op::v1::Add>(past_token_len_scalar, token_len_scalar);
|
||||
indices = std::make_shared<ov::op::v4::Range>(past_token_len_scalar,
|
||||
total_token_len_scalar,
|
||||
one_scalar,
|
||||
ov::element::i64);
|
||||
}
|
||||
indices = std::make_shared<ov::op::v0::Unsqueeze>(indices, one);
|
||||
|
||||
auto indices = context.get_input("update_indices_k");
|
||||
auto updated = std::make_shared<ov::op::v3::ScatterNDUpdate>(src1, indices, src0);
|
||||
res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(src1), false);
|
||||
} else {
|
||||
// Write V to cache_v
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
|
||||
auto zero_scalar = ov::op::v0::Constant::create(ov::element::i64, {}, {0});
|
||||
|
||||
int64_t total_head_size = src0_shape[1];
|
||||
auto total_head_size_node = ov::op::v0::Constant::create(ov::element::i64, {1}, {total_head_size});
|
||||
auto total_head_size_scalar = std::make_shared<ov::op::v0::Squeeze>(total_head_size_node, zero);
|
||||
|
||||
auto token_len = get_dimensions(src0.get_node_shared_ptr(), {2});
|
||||
auto token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(token_len, zero);
|
||||
|
||||
// 1D tensor of shape [total_head_size], values starting from 0
|
||||
auto range_row =
|
||||
std::make_shared<ov::op::v4::Range>(zero_scalar, total_head_size_scalar, one_scalar, ov::element::i64);
|
||||
auto range_row_reshaped =
|
||||
std::make_shared<ov::op::v0::Unsqueeze>(range_row,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 2}));
|
||||
auto row_indices = std::make_shared<ov::op::v3::Broadcast>(
|
||||
range_row_reshaped,
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{total_head_size_node, token_len, one}, 0));
|
||||
|
||||
// 1D tensor of shape [token_len], values starting from past_token_len
|
||||
std::shared_ptr<ov::Node> range_col;
|
||||
if (context.is_static()) {
|
||||
range_col = past_token_len.get_node_shared_ptr();
|
||||
} else {
|
||||
auto past_token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(past_token_len, zero);
|
||||
auto total_token_len_scalar = std::make_shared<ov::op::v1::Add>(past_token_len_scalar, token_len_scalar);
|
||||
range_col = std::make_shared<ov::op::v4::Range>(past_token_len_scalar,
|
||||
total_token_len_scalar,
|
||||
one_scalar,
|
||||
ov::element::i64);
|
||||
}
|
||||
auto range_col_reshaped =
|
||||
std::make_shared<ov::op::v0::Unsqueeze>(range_col,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 2}));
|
||||
auto col_indices = std::make_shared<ov::op::v3::Broadcast>(
|
||||
range_col_reshaped,
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{total_head_size_node, token_len, one}, 0));
|
||||
|
||||
// Stack row_indices and col_indices along last axis: [total_head_size, token_len, 2]
|
||||
auto indices = std::make_shared<ov::op::v0::Concat>(OutputVector{row_indices, col_indices}, 2);
|
||||
auto indices_final = std::make_shared<ov::op::v1::Reshape>(
|
||||
indices,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {2}, std::vector<int64_t>{-1, 2}),
|
||||
false);
|
||||
|
||||
auto flattend_src0 =
|
||||
std::make_shared<ov::op::v1::Reshape>(src0,
|
||||
ov::op::v0::Constant::create(element::i64, Shape{1}, {-1}),
|
||||
false);
|
||||
int64_t total_head_size = src0_shape[1];
|
||||
auto reshaped_src1 = std::make_shared<ov::op::v1::Reshape>(
|
||||
src1,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {2}, std::vector<int64_t>{total_head_size, -1}),
|
||||
false);
|
||||
|
||||
auto updated = std::make_shared<ov::op::v3::ScatterNDUpdate>(reshaped_src1, indices_final, flattend_src0);
|
||||
auto indices = context.get_input("update_indices_v");
|
||||
auto updated = std::make_shared<ov::op::v3::ScatterNDUpdate>(reshaped_src1, indices, flattend_src0);
|
||||
res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(src1), false);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -3,11 +3,20 @@
|
|||
#include <cstdlib>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <openvino/core/node.hpp>
|
||||
#include <openvino/op/broadcast.hpp>
|
||||
#include <openvino/op/concat.hpp>
|
||||
#include <openvino/op/parameter.hpp>
|
||||
#include <openvino/op/range.hpp>
|
||||
#include <openvino/op/reshape.hpp>
|
||||
#include <openvino/op/result.hpp>
|
||||
#include <openvino/op/squeeze.hpp>
|
||||
#include <openvino/op/unsqueeze.hpp>
|
||||
#include <openvino/pass/constant_folding.hpp>
|
||||
#include <openvino/pass/make_stateful.hpp>
|
||||
|
||||
#include "ggml-openvino/openvino/node_context.hpp"
|
||||
#include "ggml-openvino/openvino/utils.hpp"
|
||||
#include "input_model.hpp"
|
||||
#include "pass/fuse_to_sdpa.hpp"
|
||||
|
||||
|
|
@ -50,6 +59,83 @@ ov::pass::MakeStateful::ParamResPairs get_kv_param_res_pairs(
|
|||
}
|
||||
return pairs;
|
||||
}
|
||||
|
||||
void add_token_len(TensorMap& tensor_map) {
|
||||
auto inp_tokens = tensor_map.at("inp_tokens").get_node_shared_ptr();
|
||||
auto token_len = get_dimensions(inp_tokens, {2});
|
||||
token_len->set_friendly_name("token_len");
|
||||
tensor_map.insert({"token_len", token_len->output(0)});
|
||||
}
|
||||
|
||||
void add_kv_update_indices(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
|
||||
// cache_k layout: [S, N, H] (seq, num_heads, head_size)
|
||||
// cache_v layout: [N, H, S] (num_heads, head_size, seq)
|
||||
// When writing to cache_v, cache should be reshaped to [N*H, S] and v-curr should be flattened
|
||||
auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr();
|
||||
auto past_token_len = tensor_map.at("past_token_len").get_node_shared_ptr();
|
||||
auto token_len = tensor_map.at("token_len").get_node_shared_ptr();
|
||||
|
||||
std::shared_ptr<ov::Node> update_indices_k;
|
||||
std::shared_ptr<ov::Node> update_indices_v;
|
||||
|
||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto zero_scalar = ov::op::v0::Constant::create(ov::element::i64, {}, {0});
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
auto one_scalar = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1});
|
||||
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
|
||||
|
||||
if (ggml_model_decoder.is_static()) {
|
||||
update_indices_k = past_token_len;
|
||||
} else {
|
||||
update_indices_k =
|
||||
std::make_shared<ov::op::v0::Squeeze>(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
|
||||
}
|
||||
update_indices_k = std::make_shared<ov::op::v0::Unsqueeze>(update_indices_k, one);
|
||||
update_indices_k->set_friendly_name("update_indices_k");
|
||||
tensor_map.insert({"update_indices_k", update_indices_k->output(0)});
|
||||
|
||||
auto total_head_size = ggml_model_decoder.get_num_heads_kv() * ggml_model_decoder.get_head_size();
|
||||
auto total_head_size_node = ov::op::v0::Constant::create(ov::element::i64, {1}, {total_head_size});
|
||||
auto total_head_size_scalar = std::make_shared<ov::op::v0::Squeeze>(total_head_size_node, zero);
|
||||
|
||||
// 1D tensor of shape [total_head_size], values starting from 0
|
||||
auto range_row =
|
||||
std::make_shared<ov::op::v4::Range>(zero_scalar, total_head_size_scalar, one_scalar, ov::element::i32);
|
||||
auto range_row_reshaped =
|
||||
std::make_shared<ov::op::v0::Unsqueeze>(range_row, ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 2}));
|
||||
auto row_indices = std::make_shared<ov::op::v3::Broadcast>(
|
||||
range_row_reshaped,
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{total_head_size_node, token_len, one}, 0));
|
||||
|
||||
// 1D tensor of shape [token_len], values starting from past_token_len
|
||||
std::shared_ptr<ov::Node> range_col;
|
||||
if (ggml_model_decoder.is_static()) {
|
||||
// aka inp_pos
|
||||
range_col = past_token_len;
|
||||
} else {
|
||||
range_col =
|
||||
std::make_shared<ov::op::v0::Squeeze>(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
|
||||
}
|
||||
auto range_col_reshaped =
|
||||
std::make_shared<ov::op::v0::Unsqueeze>(range_col, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 2}));
|
||||
auto col_indices = std::make_shared<ov::op::v3::Broadcast>(
|
||||
range_col_reshaped,
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{total_head_size_node, token_len, one}, 0));
|
||||
|
||||
// Stack row_indices and col_indices along last axis: [total_head_size, token_len, 2]
|
||||
auto indices = std::make_shared<ov::op::v0::Concat>(OutputVector{row_indices, col_indices}, 2);
|
||||
update_indices_v = std::make_shared<ov::op::v1::Reshape>(
|
||||
indices, ov::op::v0::Constant::create(ov::element::i64, {2}, std::vector<int64_t>{-1, 2}), false);
|
||||
update_indices_v->set_friendly_name("update_indices_v");
|
||||
tensor_map.insert({"update_indices_v", update_indices_v->output(0)});
|
||||
}
|
||||
|
||||
// Create common patterns
|
||||
void preprocess(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
|
||||
add_token_len(tensor_map);
|
||||
add_kv_update_indices(tensor_map, ggml_model_decoder);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TranslateSession::TranslateSession(const frontend::InputModel::Ptr& input_model,
|
||||
|
|
@ -118,6 +204,7 @@ std::shared_ptr<Model> TranslateSession::translate_graph(const frontend::InputMo
|
|||
}
|
||||
};
|
||||
|
||||
preprocess(*tensor_map, *ggml_model_decoder);
|
||||
ggml_model_decoder->visit_subgraph(node_visitor);
|
||||
|
||||
for (const auto& name : ggml_model_decoder->get_model_output_names()) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue