Pull out indices creation for kv cache update

This commit is contained in:
Yu, Zijun 2025-07-06 21:59:30 +08:00 committed by Mustafa Cavus
parent bf5414c95e
commit acf358d1ce
3 changed files with 99 additions and 77 deletions

View File

@ -71,6 +71,9 @@ public:
}
// Look up a named tensor in the tensor map.
// Throws std::runtime_error if `name` has not been registered.
// NOTE: single lookup via iterator instead of the original find()+at(),
// which searched the map twice for the same key.
Output<Node> get_input(const std::string& name) const override {
    const auto it = m_tensor_map->find(name);
    if (it == m_tensor_map->end()) {
        throw std::runtime_error("'" + name + "' not found in tensor map.");
    }
    return it->second;
}

View File

@ -4,19 +4,11 @@
#include <openvino/core/node.hpp>
#include <openvino/core/node_output.hpp>
#include <openvino/core/node_vector.hpp>
#include <openvino/op/add.hpp>
#include <openvino/op/broadcast.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/range.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/scatter_nd_update.hpp>
#include <openvino/op/shape_of.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/squeeze.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/op/unsqueeze.hpp>
#include <vector>
#include "../node_context.hpp"
@ -36,8 +28,13 @@ OutputVector translate_cpy(const NodeContext& context) {
auto src0 = context.get_input(0);
auto src1 = context.get_input(1);
auto token_len = context.get_input("token_len");
auto past_token_len = context.get_input("past_token_len");
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(token_len, zero);
auto past_token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(past_token_len, zero);
src0 = std::make_shared<ov::op::v0::Convert>(src0, context.get_input_type(1));
ov::Output<Node> res;
@ -52,89 +49,24 @@ OutputVector translate_cpy(const NodeContext& context) {
std::vector<size_t> input0_strides = context.get_input_stride(0);
std::vector<size_t> output_strides = context.get_output_stride(0);
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto one_scalar = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1});
if (op_case == 1) {
// Write K to cache_k
auto token_len = get_dimensions(src0.get_node_shared_ptr(), {0});
auto token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(token_len, zero);
std::shared_ptr<ov::Node> indices;
if (context.is_static()) {
indices = past_token_len.get_node_shared_ptr();
} else {
auto past_token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(past_token_len, zero);
auto total_token_len_scalar = std::make_shared<ov::op::v1::Add>(past_token_len_scalar, token_len_scalar);
indices = std::make_shared<ov::op::v4::Range>(past_token_len_scalar,
total_token_len_scalar,
one_scalar,
ov::element::i64);
}
indices = std::make_shared<ov::op::v0::Unsqueeze>(indices, one);
auto indices = context.get_input("update_indices_k");
auto updated = std::make_shared<ov::op::v3::ScatterNDUpdate>(src1, indices, src0);
res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(src1), false);
} else {
// Write V to cache_v
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
auto zero_scalar = ov::op::v0::Constant::create(ov::element::i64, {}, {0});
int64_t total_head_size = src0_shape[1];
auto total_head_size_node = ov::op::v0::Constant::create(ov::element::i64, {1}, {total_head_size});
auto total_head_size_scalar = std::make_shared<ov::op::v0::Squeeze>(total_head_size_node, zero);
auto token_len = get_dimensions(src0.get_node_shared_ptr(), {2});
auto token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(token_len, zero);
// 1D tensor of shape [total_head_size], values starting from 0
auto range_row =
std::make_shared<ov::op::v4::Range>(zero_scalar, total_head_size_scalar, one_scalar, ov::element::i64);
auto range_row_reshaped =
std::make_shared<ov::op::v0::Unsqueeze>(range_row,
ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 2}));
auto row_indices = std::make_shared<ov::op::v3::Broadcast>(
range_row_reshaped,
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{total_head_size_node, token_len, one}, 0));
// 1D tensor of shape [token_len], values starting from past_token_len
std::shared_ptr<ov::Node> range_col;
if (context.is_static()) {
range_col = past_token_len.get_node_shared_ptr();
} else {
auto past_token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(past_token_len, zero);
auto total_token_len_scalar = std::make_shared<ov::op::v1::Add>(past_token_len_scalar, token_len_scalar);
range_col = std::make_shared<ov::op::v4::Range>(past_token_len_scalar,
total_token_len_scalar,
one_scalar,
ov::element::i64);
}
auto range_col_reshaped =
std::make_shared<ov::op::v0::Unsqueeze>(range_col,
ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 2}));
auto col_indices = std::make_shared<ov::op::v3::Broadcast>(
range_col_reshaped,
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{total_head_size_node, token_len, one}, 0));
// Stack row_indices and col_indices along last axis: [total_head_size, token_len, 2]
auto indices = std::make_shared<ov::op::v0::Concat>(OutputVector{row_indices, col_indices}, 2);
auto indices_final = std::make_shared<ov::op::v1::Reshape>(
indices,
ov::op::v0::Constant::create(ov::element::i64, {2}, std::vector<int64_t>{-1, 2}),
false);
auto flattend_src0 =
std::make_shared<ov::op::v1::Reshape>(src0,
ov::op::v0::Constant::create(element::i64, Shape{1}, {-1}),
false);
int64_t total_head_size = src0_shape[1];
auto reshaped_src1 = std::make_shared<ov::op::v1::Reshape>(
src1,
ov::op::v0::Constant::create(ov::element::i64, {2}, std::vector<int64_t>{total_head_size, -1}),
false);
auto updated = std::make_shared<ov::op::v3::ScatterNDUpdate>(reshaped_src1, indices_final, flattend_src0);
auto indices = context.get_input("update_indices_v");
auto updated = std::make_shared<ov::op::v3::ScatterNDUpdate>(reshaped_src1, indices, flattend_src0);
res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(src1), false);
}

View File

@ -3,11 +3,20 @@
#include <cstdlib>
#include <map>
#include <memory>
#include <openvino/core/node.hpp>
#include <openvino/op/broadcast.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/parameter.hpp>
#include <openvino/op/range.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/result.hpp>
#include <openvino/op/squeeze.hpp>
#include <openvino/op/unsqueeze.hpp>
#include <openvino/pass/constant_folding.hpp>
#include <openvino/pass/make_stateful.hpp>
#include "ggml-openvino/openvino/node_context.hpp"
#include "ggml-openvino/openvino/utils.hpp"
#include "input_model.hpp"
#include "pass/fuse_to_sdpa.hpp"
@ -50,6 +59,83 @@ ov::pass::MakeStateful::ParamResPairs get_kv_param_res_pairs(
}
return pairs;
}
// Expose the current token count in the tensor map under "token_len".
// The value is taken from dimension 2 of the "inp_tokens" tensor.
void add_token_len(TensorMap& tensor_map) {
    const auto inp_tokens_node = tensor_map.at("inp_tokens").get_node_shared_ptr();
    auto token_len_node = get_dimensions(inp_tokens_node, {2});
    token_len_node->set_friendly_name("token_len");
    tensor_map.insert({"token_len", token_len_node->output(0)});
}
// Build the ScatterNDUpdate index tensors used when writing the current
// K/V tokens into the KV caches, and publish them in the tensor map as
// "update_indices_k" and "update_indices_v" so op translations (translate_cpy)
// can reuse them instead of rebuilding the subgraph per layer.
// cache_k layout: [S, N, H] (seq, num_heads, head_size)
// cache_v layout: [N, H, S] (num_heads, head_size, seq)
// When writing to cache_v, cache should be reshaped to [N*H, S] and v-curr should be flattened
void add_kv_update_indices(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
    auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr();
    auto past_token_len = tensor_map.at("past_token_len").get_node_shared_ptr();
    auto token_len = tensor_map.at("token_len").get_node_shared_ptr();
    std::shared_ptr<ov::Node> update_indices_k;
    std::shared_ptr<ov::Node> update_indices_v;
    auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
    auto zero_scalar = ov::op::v0::Constant::create(ov::element::i64, {}, {0});
    auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
    auto one_scalar = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1});
    // (removed unused constant `two` — the Concat axis below is a plain int, not a node)

    // K indices: the sequence positions the new tokens are written to.
    if (ggml_model_decoder.is_static()) {
        // Static graph: past_token_len already holds the write position.
        update_indices_k = past_token_len;
    } else {
        // Dynamic graph: inp_pos carries the absolute positions of the current
        // tokens; squeeze the two leading axes down to a 1D [token_len] tensor.
        update_indices_k =
            std::make_shared<ov::op::v0::Squeeze>(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
    }
    update_indices_k = std::make_shared<ov::op::v0::Unsqueeze>(update_indices_k, one);
    update_indices_k->set_friendly_name("update_indices_k");
    tensor_map.insert({"update_indices_k", update_indices_k->output(0)});

    // V indices: [row, col] pairs over a [N*H, S] view of cache_v.
    auto total_head_size = ggml_model_decoder.get_num_heads_kv() * ggml_model_decoder.get_head_size();
    auto total_head_size_node = ov::op::v0::Constant::create(ov::element::i64, {1}, {total_head_size});
    auto total_head_size_scalar = std::make_shared<ov::op::v0::Squeeze>(total_head_size_node, zero);
    // 1D tensor of shape [total_head_size], values starting from 0
    // NOTE(review): range_row is created as i32 while range_col below comes from
    // inp_pos / past_token_len — confirm both halves of the Concat have the same
    // element type (the pre-refactor code used i64 here).
    auto range_row =
        std::make_shared<ov::op::v4::Range>(zero_scalar, total_head_size_scalar, one_scalar, ov::element::i32);
    auto range_row_reshaped =
        std::make_shared<ov::op::v0::Unsqueeze>(range_row, ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 2}));
    auto row_indices = std::make_shared<ov::op::v3::Broadcast>(
        range_row_reshaped,
        std::make_shared<ov::op::v0::Concat>(ov::OutputVector{total_head_size_node, token_len, one}, 0));
    // 1D tensor of shape [token_len], values starting from past_token_len
    std::shared_ptr<ov::Node> range_col;
    if (ggml_model_decoder.is_static()) {
        // aka inp_pos
        range_col = past_token_len;
    } else {
        range_col =
            std::make_shared<ov::op::v0::Squeeze>(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
    }
    auto range_col_reshaped =
        std::make_shared<ov::op::v0::Unsqueeze>(range_col, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 2}));
    auto col_indices = std::make_shared<ov::op::v3::Broadcast>(
        range_col_reshaped,
        std::make_shared<ov::op::v0::Concat>(ov::OutputVector{total_head_size_node, token_len, one}, 0));
    // Stack row_indices and col_indices along last axis: [total_head_size, token_len, 2]
    auto indices = std::make_shared<ov::op::v0::Concat>(OutputVector{row_indices, col_indices}, 2);
    // Flatten to the [K, 2] index layout ScatterNDUpdate expects.
    update_indices_v = std::make_shared<ov::op::v1::Reshape>(
        indices, ov::op::v0::Constant::create(ov::element::i64, {2}, std::vector<int64_t>{-1, 2}), false);
    update_indices_v->set_friendly_name("update_indices_v");
    tensor_map.insert({"update_indices_v", update_indices_v->output(0)});
}
// Create common patterns
// Populate tensor_map with shared subgraphs used by multiple op translations.
// Order matters: add_kv_update_indices reads the "token_len" entry that
// add_token_len inserts.
void preprocess(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
    add_token_len(tensor_map);
    add_kv_update_indices(tensor_map, ggml_model_decoder);
}
} // namespace
TranslateSession::TranslateSession(const frontend::InputModel::Ptr& input_model,
@ -118,6 +204,7 @@ std::shared_ptr<Model> TranslateSession::translate_graph(const frontend::InputMo
}
};
preprocess(*tensor_map, *ggml_model_decoder);
ggml_model_decoder->visit_subgraph(node_visitor);
for (const auto& name : ggml_model_decoder->get_model_output_names()) {