// llama.cpp/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp

#include "../node_context.hpp"
#include "../op_table.hpp"
#include "../utils.hpp"

#include <cstdint>
#include <memory>
#include <string>

#include <openvino/op/broadcast.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/scaled_dot_product_attention.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/op/unsqueeze.hpp>
namespace ov {
namespace frontend {
namespace ggml {
namespace op {
/// Translate GGML's FLASH_ATTN_EXT op into an OpenVINO subgraph built around
/// ScaledDotProductAttention-13.
///
/// Inputs (in order): Q (f32), K, V, mask. The op params carry the raw float
/// triple [scale, max_bias, logit_softcap]; only `scale` is consumed here.
/// Returns a single f32 output, renamed with the node's name suffix.
OutputVector translate_flash_attn_ext(const NodeContext & context) {
    num_inputs_check(context, 4, 4);

    auto q_f32 = context.get_input(0);
    auto k = context.get_input(1);
    auto v = context.get_input(2);
    auto mask = context.get_input(3);

    // GGML stores op parameters as a raw blob; reinterpret as floats.
    const float * params = reinterpret_cast<const float *>(context.get_output_op_params());
    const float scale = params[0];
    // float max_bias = params[1];      // ALiBi slope base — not handled here
    // float logit_softcap = params[2]; // logit soft-capping — not handled here

    // Attention is computed in f16: convert Q and build the scale constant in f16.
    auto q = std::make_shared<ov::op::v0::Convert>(q_f32, ov::element::f16);
    auto scale_node =
        std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{}, std::vector<float>{scale});

    // Pick the pre-sliced mask input if the graph provides one; the SWA
    // (sliding-window attention) mask has its own dedicated name.
    std::string mask_name = "KQ_mask_sliced";
    if (context.get_input_names()[3].find("swa") != std::string::npos) {
        mask_name = "KQ_mask_swa_sliced";
    }

    ov::Output<ov::Node> mask_sliced;
    if (context.has_input(mask_name)) {
        mask_sliced = context.get_input(mask_name);
    } else {
        // No pre-sliced mask available: slice the full mask down to the current
        // token length. Slice(data, start=0, stop=token_len, step=1, axes=[2]).
        auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
        auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
        auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
        auto token_len = get_dimensions(q, {2});
        mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, two);
    }
    if (mask_sliced.get_element_type() != ov::element::f16) {
        mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
    }

    // Grouped-query attention: repeat each K/V head `factor` times along a new
    // axis so K/V end up with `num_heads` heads, matching Q:
    //   unsqueeze(axis=2) -> broadcast factor on that axis -> reshape back to
    //   rank 4 (`0` keeps the batch dim, `-1` infers the token dim).
    // Captures nothing from the enclosing scope, hence the empty capture list.
    auto tile_kv = [](int64_t num_heads, int64_t num_heads_kv, int64_t head_size, ov::Output<Node> kv) {
        const int64_t factor = num_heads / num_heads_kv;
        // NOTE(review): tiling is skipped when num_heads_kv == 1 (MQA);
        // presumably SDPA's implicit broadcasting covers that case — confirm.
        if (factor > 1 && num_heads_kv > 1) {
            auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});
            auto kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
            auto kv_broadcast_shape = ov::op::v0::Constant::create(
                ov::element::i64, {5}, {(int64_t) 1, (int64_t) 1, factor, (int64_t) 1, (int64_t) 1});
            auto new_kv_shape =
                ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 0, num_heads, (int64_t) -1, head_size});
            kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape,
                                                         ov::op::BroadcastType::BIDIRECTIONAL);
            kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, true);
        }
        return kv;
    };

    // Q layout here: dim 1 = head count, dim 2 = token length, dim 3 = head size
    // (dim 2/3 usage is fixed by get_dimensions/tile_kv above). K and V are
    // assumed to share the same head count, so k_shape[1] is used for both.
    auto q_shape = context.get_input_shape(0).to_shape();
    auto k_shape = context.get_input_shape(1).to_shape();
    k = tile_kv(q_shape[1], k_shape[1], q_shape[3], k);
    v = tile_kv(q_shape[1], k_shape[1], q_shape[3], v);

    auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_sliced, scale_node, false);

    // Swap the head and token axes back ([0,2,1,3]) and return to f32.
    ov::Output<ov::Node> res = std::make_shared<ov::op::v1::Transpose>(
        sdpa, ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
    res = std::make_shared<ov::op::v0::Convert>(res, ov::element::f32);

    return rename_outputs_with_suffix({res}, context.get_name());
}
} // namespace op
} // namespace ggml
} // namespace frontend
} // namespace ov