kvcachefusion support

This commit is contained in:
cavusmustafa 2025-10-01 14:02:11 -07:00 committed by Mustafa Cavus
parent 973a80fd02
commit c112bc4e73
8 changed files with 145 additions and 55 deletions

View File

@ -316,9 +316,13 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
input_shape = ov::PartialShape{1, -1, -1};
}
} else if (name.find("cache_") == 0) {
int layer = extract_layer_from_name(name);
bool is_swa = is_swa_layer(layer);
input_shape = ov::PartialShape{is_swa ? m_context_size_swa : m_context_size, m_num_heads_kv, m_head_size};
if (m_is_static) {
int layer = extract_layer_from_name(name);
bool is_swa = is_swa_layer(layer);
input_shape = ov::PartialShape{is_swa ? m_context_size_swa : m_context_size, m_num_heads_kv, m_head_size};
} else {
input_shape = ov::PartialShape{1, -1, m_num_heads_kv, m_head_size};
}
} else if (const auto* op = get_tensor_used_op(src); op && op->op == GGML_OP_SET_ROWS) {
input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
} else if (src->op == GGML_OP_VIEW) {

View File

@ -2,6 +2,7 @@
#include <openvino/op/broadcast.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/gather.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/scaled_dot_product_attention.hpp>
#include <openvino/op/transpose.hpp>
@ -32,7 +33,7 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
auto q = std::make_shared<ov::op::v0::Convert>(q_f32, ov::element::f16);
auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{}, std::vector<float>{scale});
ov::Output<ov::Node> mask_sliced;
ov::Output<ov::Node> mask_sliced, res;
std::string mask_name = "KQ_mask_sliced";
if (context.get_input_names()[3].find("swa") != std::string::npos) {
mask_name = "KQ_mask_swa_sliced";
@ -40,33 +41,55 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
if (context.has_input(mask_name)) {
mask_sliced = context.get_input(mask_name);
} else {
auto token_len = get_dimensions(q, {1});
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
auto token_len = get_dimensions(q, {2});
auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0,0});
auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,1});
auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,2});
auto leaf_8 = context.get_input("leaf_8");
auto shape_of_leaf_8 = std::make_shared<ov::op::v3::ShapeOf>(leaf_8);
auto gather_leaf_8 = std::make_shared<ov::op::v8::Gather>(shape_of_leaf_8, two_1d, zero_1d);
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, gather_leaf_8}, 0);
mask_sliced =
std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
mask_sliced = std::make_shared<ov::op::v0::Unsqueeze>(mask_sliced, zero_1d);
}
if (mask_sliced.get_element_type() != ov::element::f16) {
mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
}
auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<Node> kv) {
auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<Node> kv, bool is_static) {
int64_t factor = q_batch / kv_batch;
if (factor > 1) {
auto q_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{q_batch});
auto kv_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{kv_batch});
auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
auto kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
ov::Output<ov::Node> kv_broadcast_shape, kv_unsqueezed, new_kv_shape;
if (is_static) {
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2});
kv_broadcast_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0);
new_kv_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0);
} else {
auto one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});
kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {2, 3});
kv_broadcast_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one_1d, kv_batch_node, factor_node, kv_last_two_dims}, 0);
new_kv_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one_1d, q_batch_node, kv_last_two_dims}, 0);
}
auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2});
auto kv_broadcast_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0);
kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape);
auto new_kv_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0);
kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, false);
}
return kv;
@ -74,13 +97,18 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
auto q_shape = context.get_input_shape(0).to_shape();
auto k_shape = context.get_input_shape(1).to_shape();
k = tile_kv(q_shape[0], k_shape[0], k);
v = tile_kv(q_shape[0], k_shape[0], v);
k = tile_kv(q_shape[0], k_shape[0], k, context.is_static());
v = tile_kv(q_shape[0], k_shape[0], v, context.is_static());
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_sliced, scale_node, false);
auto sdpa_f32 = std::make_shared<ov::op::v0::Convert>(sdpa, ov::element::f32);
auto res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
if (context.is_static()) {
res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
} else {
res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
}
return rename_outputs_with_suffix({res}, context.get_name());
}

View File

@ -59,13 +59,23 @@ OutputVector translate_mulmat(const NodeContext& context) {
auto Z_last_two_dims = get_dimensions(Z.get_node_shared_ptr(), {1, 2});
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
auto Z_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(Z, unsqueeze_axes);
Output<Node> batch_small = A_batch_larger ? B_batch_node : A_batch_node;
Output<Node> batch_large = A_batch_larger ? A_batch_node : B_batch_node;
auto broadcast_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_small, factor_node, Z_last_two_dims}, 0);
ov::Output<Node> broadcast_shape;
ov::Output<Node> Z_unsqueezed;
if (context.is_static()) {
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
Z_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(Z, unsqueeze_axes);
broadcast_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_small, factor_node, Z_last_two_dims}, 0);
} else {
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});
Z_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(Z, unsqueeze_axes);
auto one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
broadcast_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one_1d, batch_small, factor_node, Z_last_two_dims}, 0);
}
auto Z_broadcasted = std::make_shared<ov::op::v3::Broadcast>(Z_unsqueezed, broadcast_shape);
auto new_Z_shape = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_large, Z_last_two_dims}, 0);

View File

@ -25,8 +25,13 @@ OutputVector translate_permute(const NodeContext& context) {
ov::Output<Node> res;
if (op_case == 1) {
res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
if (context.is_static()) {
res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
} else {
res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
}
} else {
auto src = context.get_input(0);
Output<Node> attention_size;
@ -38,20 +43,23 @@ OutputVector translate_permute(const NodeContext& context) {
attention_size = context.get_input("attention_size_swa");
}
auto src_shape_ = context.get_input_shape(0).to_shape();
std::vector<int64_t> src_shape(src_shape_.begin(), src_shape_.end());
auto src_reshaped = std::make_shared<ov::op::v1::Reshape>(
src,
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{-1, src_shape[1], src_shape[2]}),
false);
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto src_slice = std::make_shared<ov::op::v8::Slice>(src_reshaped, zero, attention_size, one, zero);
res = std::make_shared<ov::op::v1::Transpose>(src_slice,
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
if (context.is_static()) {
auto src_shape_ = context.get_input_shape(0).to_shape();
std::vector<int64_t> src_shape(src_shape_.begin(), src_shape_.end());
auto src_reshaped = std::make_shared<ov::op::v1::Reshape>(
src,
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{-1, src_shape[1], src_shape[2]}),
false);
auto src_slice = std::make_shared<ov::op::v8::Slice>(src_reshaped, zero, attention_size, one, zero);
res = std::make_shared<ov::op::v1::Transpose>(src_slice,
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
} else {
res = std::make_shared<ov::op::v1::Transpose>(src,
ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
}
}
return rename_outputs_with_suffix({res}, context.get_name());
}

View File

@ -84,6 +84,9 @@ OutputVector translate_rope(const NodeContext& context) {
ov::op::v0::Constant::create(ov::element::i64, {1}, {3}));
auto stack = std::make_shared<ov::op::v0::Concat>(OutputVector{first_half, second_half}, 3);
res = std::make_shared<ov::op::v1::Reshape>(stack, std::make_shared<ov::op::v0::ShapeOf>(data_node), false);
if (!(context.is_static())) {
res = std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
}
} else if (mode == ROPE_TYPE_NEOX) {
auto data_split = std::make_shared<ov::op::v1::Split>(
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2}), 2);

View File

@ -3,10 +3,12 @@
#include <openvino/core/node.hpp>
#include <openvino/core/node_output.hpp>
#include <openvino/frontend/exception.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/gather.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/unsqueeze.hpp>
#include <openvino/op/scatter_update.hpp>
#include <openvino/op/shape_of.hpp>
#include <openvino/op/slice.hpp>
@ -39,17 +41,29 @@ OutputVector translate_set_rows(const NodeContext& context) {
auto dst = context.get_input(context.get_output_name());
auto zero = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
auto dst_reshaped = std::make_shared<ov::op::v1::Reshape>(
dst,
ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) dst_shape[1], (int64_t) dst_shape[2]}),
false);
auto indices_reshaped =
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
auto data_reshaped = std::make_shared<ov::op::v1::Reshape>(
data, ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) -1, (int64_t) dst_shape[2]}), false);
Output<Node> res;
if (context.is_static()) {
auto dst_reshaped = std::make_shared<ov::op::v1::Reshape>(
dst,
ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) dst_shape[1], (int64_t) dst_shape[2]}),
false);
auto indices_reshaped =
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
auto data_reshaped = std::make_shared<ov::op::v1::Reshape>(
data, ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) -1, (int64_t) dst_shape[2]}), false);
auto updated = std::make_shared<ov::op::v3::ScatterUpdate>(dst_reshaped, indices_reshaped, data_reshaped, zero);
auto res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(dst), false);
auto updated = std::make_shared<ov::op::v3::ScatterUpdate>(dst_reshaped, indices_reshaped, data_reshaped, zero);
res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(dst), false);
} else {
// TODO: Better solution would be to reshape the data into 4D at first place (for stateful model)
if (data.get_partial_shape().rank() + 1 == dst.get_partial_shape().rank()) {
data = std::make_shared<ov::op::v0::Unsqueeze>(data, zero);
}
int concat_axis = 1;
if (context.is_static())
concat_axis = 0;
res = std::make_shared<ov::op::v0::Concat>(OutputVector{dst, data}, concat_axis);
}
return rename_outputs_with_suffix({res}, context.get_name());
}

View File

@ -7,8 +7,10 @@
#include <openvino/op/concat.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/gather.hpp>
#include <openvino/op/matmul.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/unsqueeze.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/softmax.hpp>
#include <vector>
@ -57,9 +59,20 @@ OutputVector translate_soft_max(const NodeContext& context) {
} else {
auto token_len = get_dimensions(input_node, {1});
auto mask_node = context.get_input(1);
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
mask_node_sliced = std::make_shared<ov::op::v8::Slice>(mask_node, zero, token_len, one, one);
auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0,0});
auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,1});
auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,2});
auto leaf_8 = context.get_input("leaf_8");
auto shape_of_leaf_8 = std::make_shared<ov::op::v3::ShapeOf>(leaf_8);
auto gather_leaf_8 = std::make_shared<ov::op::v8::Gather>(shape_of_leaf_8, two_1d, zero_1d);
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, gather_leaf_8}, 0);
mask_node_sliced =
std::make_shared<ov::op::v8::Slice>(mask_node, zero_2d, stop, one_2d, axes);
if (!(context.is_static())) {
mask_node_sliced = std::make_shared<ov::op::v0::Unsqueeze>(mask_node_sliced, zero_1d);
}
}
if (mask_node_sliced.get_element_type() != context.get_output_type(0)) {

View File

@ -11,6 +11,7 @@
#include <openvino/op/convert.hpp>
#include <openvino/op/cos.hpp>
#include <openvino/op/divide.hpp>
#include <openvino/op/gather.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/parameter.hpp>
#include <openvino/op/range.hpp>
@ -87,9 +88,18 @@ void add_sliced_mask(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
if (is_static) {
mask_sliced = mask;
} else {
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0,0});
auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,1});
auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,2});
auto leaf_8 = tensor_map.at("leaf_8").get_node_shared_ptr();
auto shape_of_leaf_8 = std::make_shared<ov::op::v3::ShapeOf>(leaf_8);
auto gather_leaf_8 = std::make_shared<ov::op::v8::Gather>(shape_of_leaf_8, two_1d, zero_1d);
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, gather_leaf_8}, 0);
mask_sliced =
std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
mask_sliced = std::make_shared<ov::op::v0::Unsqueeze>(mask_sliced, zero_1d);
mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
mask_sliced->set_friendly_name(sliced_name);
}