Fuse to SDPA

This commit is contained in:
Yu, Zijun 2025-07-03 13:22:39 +08:00 committed by Mustafa Cavus
parent 73ee84fffe
commit ebc4fc9f95
12 changed files with 189 additions and 93 deletions

View File

@ -26,27 +26,36 @@
#include "ggml-backend-impl.h"
#include "ggml-backend.h"
GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* m_cgraph, bool is_static,
GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token,
int context_size, int num_heads, int num_heads_kv, int head_size) :
GgmlOvDecoder::GgmlOvDecoder(node, cgraph, is_static, is_first_token) {
m_context_size = context_size;
m_num_heads = num_heads;
m_num_heads_kv = num_heads_kv;
m_head_size = head_size;
}
GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static,
bool is_first_token) :
m_cgraph(m_cgraph),
m_cgraph(cgraph),
m_node(node),
m_op_name(m_node ? std::string(m_node->name) : "NONE_OP"),
m_is_static(is_static),
m_is_first_token(is_first_token) {
// TODO avoid static
static std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
if (m_node) {
set_input_output(m_node);
} else {
static bool printed = false;
if (!printed && getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS")) {
print_tensor_address_map(m_cgraph);
print_tensor_address_map(cgraph);
printed = true;
}
if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) {
std::string filename = "cgraph.txt";
dump_cgraph(m_cgraph, filename);
dump_cgraph(cgraph, filename);
}
set_llm_params();
@ -57,8 +66,8 @@ GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* m_cgr
weight_created = true;
}
for (int node_n = 0; node_n < m_cgraph->n_nodes; node_n++) {
auto* cur_node = m_cgraph->nodes[node_n];
for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
auto* cur_node = cgraph->nodes[node_n];
m_nodes.push_back(cur_node);
set_input_output(cur_node);
}
@ -195,7 +204,7 @@ void GgmlOvDecoder::set_llm_params() {
auto* node = m_cgraph->nodes[i];
if (node->op == GGML_OP_VIEW && std::string(node->name) == "cache_k_l0 (view)") {
auto* cache_k = node->src[0];
m_max_token_len = cache_k->ne[1];
m_context_size = cache_k->ne[1];
} else if (node->op == GGML_OP_ROPE && std::string(node->name) == "Qcur-0") {
m_head_size = node->ne[0];
m_num_heads = node->ne[1];
@ -210,30 +219,30 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
if (std::string(src->name) == "inp_tokens" || std::string(src->name) == "inp_pos") {
if (m_is_static) {
if (m_is_first_token) {
input_shape = ov::PartialShape{ 1, 1, m_max_token_len };
input_shape = ov::PartialShape{1, 1, m_context_size};
} else {
input_shape = ov::PartialShape{ 1, 1, 1 };
input_shape = ov::PartialShape{1, 1, 1};
}
} else {
input_shape = ov::PartialShape{ 1, 1, ov::Dimension(1, m_max_token_len) };
input_shape = ov::PartialShape{1, 1, ov::Dimension(1, m_context_size)};
}
} else if (std::string(src->name) == "KQ_mask") {
if (m_is_static) {
if (m_is_first_token) {
input_shape = ov::PartialShape{ 1, m_max_token_len, m_max_token_len };
input_shape = ov::PartialShape{1, m_context_size, m_context_size};
} else {
input_shape = ov::PartialShape{ 1, 1, m_max_token_len };
input_shape = ov::PartialShape{1, 1, m_context_size};
}
} else {
auto max_mask_size = GGML_PAD(m_max_token_len, GGML_KQ_MASK_PAD);
input_shape = ov::PartialShape{ 1, ov::Dimension(1, max_mask_size), ov::Dimension(1, max_mask_size) };
auto max_mask_size = GGML_PAD(m_context_size, GGML_KQ_MASK_PAD);
input_shape = ov::PartialShape{1, ov::Dimension(1, max_mask_size), ov::Dimension(1, max_mask_size)};
}
} else if (std::string(src->name).find("cache_k") == 0) {
input_shape = ov::PartialShape{ m_max_token_len, m_num_heads_kv, m_head_size };
input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size};
} else if (std::string(src->name).find("cache_v") == 0) {
input_shape = ov::PartialShape{ m_num_heads_kv, m_head_size, m_max_token_len };
input_shape = ov::PartialShape{m_num_heads_kv, m_head_size, m_context_size};
} else {
input_shape = ov::PartialShape{ get_shape(src) };
input_shape = ov::PartialShape{get_shape(src)};
}
return input_shape;
}
@ -557,7 +566,8 @@ int32_t* GgmlOvDecoder::get_output_op_params(const std::string& name) const {
void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>)> node_visitor) const {
for (const auto& node : m_nodes) {
auto decoder = std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_is_first_token);
auto decoder = std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_is_first_token, m_context_size,
m_num_heads, m_num_heads_kv, m_head_size);
node_visitor(decoder);
}
}

View File

@ -11,9 +11,9 @@
class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
public:
using ov::frontend::ggml::GgmlDecoder::GgmlDecoder;
GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token);
GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token,
int context_size, int num_heads, int num_heads_kv, int head_size);
virtual ov::Any get_attribute(const std::string& name) const override {
return nullptr;
@ -90,7 +90,7 @@ public:
return m_model_output_names;
}
virtual int get_max_token_len() const override { return m_max_token_len; }
virtual int get_context_size() const override { return m_context_size; }
virtual int get_num_heads() const override { return m_num_heads; }
@ -114,7 +114,7 @@ private:
static std::vector<size_t> get_stride(const ggml_tensor* tensor);
static ov::element::Type get_ov_type(const ggml_tensor* tensor);
// set max_token_len, num_heads, etc
// set context_size, num_heads, etc
void set_llm_params();
static std::shared_ptr<ov::Node> create_weight_node(ggml_tensor* tensor);
@ -136,7 +136,7 @@ private:
std::map<std::string, std::shared_ptr<ov::Tensor>> m_model_extra_input_values;
std::map<std::string, std::shared_ptr<ov::Node>> m_model_weights;
std::vector<std::string> m_model_output_names;
int m_max_token_len;
int m_context_size;
int m_num_heads;
int m_num_heads_kv;
int m_head_size;

View File

@ -65,7 +65,7 @@ public:
virtual bool is_static() const = 0;
virtual bool is_first_token() const = 0;
virtual int get_max_token_len() const = 0;
virtual int get_context_size() const = 0;
};
} // namespace ggml

View File

@ -91,11 +91,16 @@ public:
bool is_first_token() const {
return m_decoder->is_first_token();
}
int get_max_token_len() const {
return m_decoder->get_max_token_len();
}
private:
int get_num_heads() const { return m_decoder->get_num_heads(); }
int get_num_heads_kv() const { return m_decoder->get_num_heads_kv(); }
int get_head_size() const { return m_decoder->get_head_size(); }
int get_context_size() const { return m_decoder->get_context_size(); }
private:
std::shared_ptr<GgmlDecoder> m_decoder;
std::shared_ptr<TensorMap>& m_tensor_map;
TranslateSession* m_translate_session;

View File

@ -38,9 +38,8 @@ OutputVector translate_mulmat(const NodeContext& context) {
ov::Output<ov::Node> B = context.get_input(0);
ov::Output<ov::Node> A = std::make_shared<ov::op::v0::Convert>(context.get_input(1), context.get_input_type(0));
auto src0_shape = context.get_input_shape(0).to_shape();
int64_t num_heads = context.get_input_shape(1).to_shape()[0];
int64_t num_heads_kv = src0_shape[0];
int64_t num_heads = context.get_num_heads();
int64_t num_heads_kv = context.get_num_heads_kv();
int64_t kv_num_heads_factor = num_heads / num_heads_kv;
if (kv_num_heads_factor > 1) {
auto num_heads_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{num_heads});

View File

@ -27,7 +27,7 @@ OutputVector translate_permute(const NodeContext& context) {
if (op_case == 1) {
auto perm = argsort_descend(context.get_output_stride(0));
res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
ov::op::v0::Constant::create(ov::element::i64, { 3 }, perm));
ov::op::v0::Constant::create(ov::element::i64, {3}, perm));
} else {
auto src = context.get_input(0);
auto attention_size = context.get_input("attention_size");
@ -51,19 +51,16 @@ OutputVector translate_permute(const NodeContext& context) {
false);
}
auto slice_start = ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>(3, 0));
auto slice_step = ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>(3, 1));
std::shared_ptr<ov::Node> slice_end;
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
std::shared_ptr<ov::Node> slice_axis;
if (op_case == 2) {
slice_end = std::make_shared<ov::op::v0::Concat>(
ov::OutputVector{attention_size, ov::op::v0::Constant::create(ov::element::i64, {2}, {src_shape[1], src_shape[2]})},
0);
slice_axis = zero;
} else {
slice_end = std::make_shared<ov::op::v0::Concat>(
ov::OutputVector{ov::op::v0::Constant::create(ov::element::i64, {2}, {src_shape[1], src_shape[0]}), attention_size},
0);
slice_axis = two;
}
auto src_slice = std::make_shared<ov::op::v8::Slice>(src_reshaped, slice_start, slice_end, slice_step);
auto src_slice = std::make_shared<ov::op::v8::Slice>(src_reshaped, zero, attention_size, one, slice_axis);
if (op_case == 2) {
res = std::make_shared<ov::op::v1::Transpose>(src_slice, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
@ -71,7 +68,7 @@ OutputVector translate_permute(const NodeContext& context) {
res = src_slice;
}
}
return rename_outputs_with_suffix({ res }, context.get_name());
return rename_outputs_with_suffix({res}, context.get_name());
}
} // namespace op

View File

@ -1,3 +1,5 @@
#include <climits>
#include <cstdint>
#include <memory>
#include <openvino/core/node.hpp>
#include <openvino/core/node_output.hpp>
@ -5,6 +7,7 @@
#include <openvino/op/concat.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/matmul.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/softmax.hpp>
@ -22,62 +25,61 @@ namespace op {
OutputVector translate_soft_max(const NodeContext& context) {
num_inputs_check(context, 1, 2);
auto input_node = context.get_input(0);
auto input_node = context.get_input(0).get_node_shared_ptr();
ov::Output<Node> res;
float scale = 1.0f;
float max_bias = 0.0f;
auto * op_params = context.get_output_op_params(0);
memcpy(&scale, (float*)op_params + 0, sizeof(float));
memcpy(&max_bias, (float*)op_params + 1, sizeof(float));
auto* op_params = context.get_output_op_params(0);
memcpy(&scale, (float*) op_params + 0, sizeof(float));
memcpy(&max_bias, (float*) op_params + 1, sizeof(float));
const uint32_t h = context.get_head_size();
// const uint32_t n_head = context.get_input_shape(0)[0].get_length();
// const uint32_t n_head_log2 = 1u << (uint32_t)floor(log2(n_head));
const uint32_t n_head = context.get_input_shape(0)[0].get_length();
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
// const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
// const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
// const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)
// : 1.0f;
const float slope = 1.0;
const float m0 = powf(2.0f, -(max_bias) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
const float slope =
(max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f;
std::shared_ptr<ov::Node> scaled_input;
if (scale != 1.0f) {
auto scale_node =
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{scale});
input_node = std::make_shared<ov::op::v1::Multiply>(input_node, scale_node);
scaled_input = std::make_shared<ov::op::v1::Multiply>(input_node, scale_node);
}
if (context.get_input_size() == 2) {
// Calculate mask then softmax
auto mask_node = context.get_input(1);
ov::element::Type mask_type = context.get_input_type(1);
if (mask_type == ov::element::f16) {
// Convert f16 to f32
mask_node = std::make_shared<ov::op::v0::Convert>(mask_node, ov::element::f32);
}
auto mask_node = context.get_input(1);
// Stride slice mask node
Output<Node> slice_start = ov::op::v0::Constant::create(ov::element::i64, Shape{3}, {0, 0, 0});
auto one = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {1});
auto token_len = get_dimensions(input_node.get_node_shared_ptr(), {1});
auto total_token_len = get_dimensions(mask_node.get_node_shared_ptr(), {2});
auto slice_end = std::make_shared<ov::op::v0::Concat>(ov::NodeVector{one, token_len, total_token_len}, 0);
Output<Node> slice_stride = ov::op::v0::Constant::create(ov::element::i64, Shape{3}, {1, 1, 1});
auto mask_node_sliced = std::make_shared<ov::op::v8::Slice>(mask_node, slice_start, slice_end, slice_stride);
// Use Q-cur to retrieve the token length, so that the translation of SOFT_MAX
// does not depend on the result of the QK MatMul, so that QK matmul + softmax + qkv matmul
// can be fused into SDPA.
if (input_node->get_type_info() != ov::op::v0::Convert::get_type_info_static()) {
throw std::runtime_error("Input of SOFT_MAX should be MatMul of qk followed by a Convert");
}
auto qk = input_node->get_input_node_shared_ptr(0);
if (qk->get_type_info() != ov::op::v0::MatMul::get_type_info_static()) {
throw std::runtime_error("Input of SOFT_MAX should be MatMul of qk followed by a Convert");
}
auto token_len = get_dimensions(qk->get_input_node_shared_ptr(0), {1});
// slope * mask
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto mask_node_sliced = std::make_shared<ov::op::v8::Slice>(mask_node, zero, token_len, one, one);
Output<Node> slope_mask;
if (slope != 1.0f) {
auto slope_node =
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{slope});
auto slope_mask_node = std::make_shared<ov::op::v1::Multiply>(mask_node_sliced, slope_node);
// input + slope * mask
auto input_slope_mask_node = std::make_shared<ov::op::v1::Add>(input_node, slope_mask_node);
// Calculate softmax
res = std::make_shared<ov::op::v8::Softmax>(input_slope_mask_node, 2);
} else {
// Directly softmax
res = std::make_shared<ov::op::v8::Softmax>(input_node, 0);
slope_mask = std::make_shared<ov::op::v1::Multiply>(mask_node_sliced, slope_node);
throw std::runtime_error("Slope != 1.0f in softmax has not been tested, verify it before use.");
}
slope_mask = mask_node_sliced;
auto input_slope_mask_node = std::make_shared<ov::op::v1::Add>(scaled_input, slope_mask);
res = std::make_shared<ov::op::v8::Softmax>(input_slope_mask_node, 2);
return rename_outputs_with_suffix({res}, context.get_name());
}

View File

@ -0,0 +1,61 @@
#include "fuse_to_sdpa.hpp"
#include <openvino/core/graph_util.hpp>
#include <openvino/core/rt_info.hpp>
#include <openvino/op/add.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/matmul.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/scaled_dot_product_attention.hpp>
#include <openvino/op/softmax.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/pass/pattern/op/label.hpp>
#include <openvino/pass/pattern/op/pattern.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
namespace ov {
namespace frontend {
namespace ggml {
namespace pass {
// Matcher pass that collapses the attention subgraph emitted by the ggml
// frontend — MatMul(Q,K) -> Convert(f32) -> Multiply(scale) -> Add(mask)
// -> Softmax -> Convert(f16) -> MatMul(.,V) — into a single
// ScaledDotProductAttention node.
FuseToSDPA::FuseToSDPA() {
    namespace pat = ov::pass::pattern;

    // Build the decomposed-attention pattern from the inside out.
    const auto key = pat::any_input();
    const auto query = pat::any_input();
    const auto qk = pat::wrap_type<ov::op::v0::MatMul>({query, key});
    const auto qk_f32 = pat::wrap_type<ov::op::v0::Convert>({qk});
    const auto scale = pat::any_input();
    const auto qk_scaled = pat::wrap_type<ov::op::v1::Multiply>({qk_f32, scale});
    const auto mask = pat::any_input();
    const auto qk_masked = pat::wrap_type<ov::op::v1::Add>({qk_scaled, mask});
    const auto qk_softmax = pat::wrap_type<ov::op::v8::Softmax>({qk_masked});
    const auto qk_softmax_f16 = pat::wrap_type<ov::op::v0::Convert>({qk_softmax});
    const auto value = pat::any_input();
    const auto qkv = pat::wrap_type<ov::op::v0::MatMul>({qk_softmax_f16, value});

    const auto callback = [=](pat::Matcher& m) {
        auto& pattern_map = m.get_pattern_value_map();
        auto q_out = pattern_map[query];
        auto k_out = pattern_map[key];
        auto v_out = pattern_map[value];
        auto mask_out = pattern_map[mask];
        auto scale_out = pattern_map[scale];

        // Swap the last two dims of V — presumably the ggml graph carries V
        // transposed relative to SDPA's expected layout (TODO confirm).
        auto v_perm = ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 2, 1});
        auto v_trans = register_new_node<ov::op::v1::Transpose>(v_out, v_perm);
        // Mask and scale are converted to f16; NOTE(review): assumes Q/K/V are
        // f16 here — verify against the producing graph.
        auto mask_f16 = register_new_node<ov::op::v0::Convert>(mask_out, ov::element::f16);
        auto scale_f16 = register_new_node<ov::op::v0::Convert>(scale_out, ov::element::f16);

        auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q_out, k_out, v_trans, mask_f16,
                                                                             scale_f16, false);
        ov::replace_node(m.get_match_root(), sdpa);
        ov::copy_runtime_info(m.get_matched_nodes(), sdpa);
        return true;
    };

    register_matcher(std::make_shared<pat::Matcher>(qkv, "ov::frontend::ggml::pass::FuseToSDPA"), callback);
}
} // namespace pass
} // namespace ggml
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,17 @@
#include "openvino/pass/matcher_pass.hpp"
namespace ov {
namespace frontend {
namespace ggml {
namespace pass {
// Matcher pass that fuses the ggml frontend's decomposed attention subgraph
// (QK MatMul -> f32 Convert -> scale Multiply -> mask Add -> Softmax ->
// f16 Convert -> QKV MatMul) into a single ScaledDotProductAttention node.
// Registered by TranslateSession::apply_transformations.
class FuseToSDPA : public ov::pass::MatcherPass {
public:
OPENVINO_MATCHER_PASS_RTTI("ov::frontend::ggml::pass::FuseToSDPA")
// Constructs the pattern and registers the fusion callback.
FuseToSDPA();
};
} // namespace pass
} // namespace ggml
} // namespace frontend
} // namespace ov

View File

@ -9,6 +9,7 @@
#include <openvino/pass/make_stateful.hpp>
#include "input_model.hpp"
#include "pass/fuse_to_sdpa.hpp"
namespace ov {
namespace frontend {
@ -145,6 +146,8 @@ void TranslateSession::apply_transformations(const std::shared_ptr<Model>& model
const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);
manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
manager.register_pass<pass::FuseToSDPA>();
}
manager.run_passes(model);

View File

@ -65,7 +65,7 @@ template <typename T>
OutputVector translate_1to1_match_2_inputs(const NodeContext& context) {
num_inputs_check(context, 2, 2);
auto res = std::make_shared<T>(context.get_input(0), context.get_input(1));
return rename_outputs_with_suffix({ res }, context.get_name());
return rename_outputs_with_suffix({res}, context.get_name());
}
} // namespace op

View File

@ -17,6 +17,7 @@
#include <openvino/runtime/compiled_model.hpp>
#include <openvino/runtime/infer_request.hpp>
#include <openvino/runtime/intel_npu/properties.hpp>
#include <openvino/runtime/properties.hpp>
#include <openvino/runtime/tensor.hpp>
#include <unordered_map>
#include <vector>
@ -88,6 +89,7 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
if (cache_dir && !is_static) {
core.set_property(ov::cache_dir(cache_dir));
}
// core.set_property(ov::enable_profiling(true));
static std::unordered_map<struct ggml_cgraph*, std::shared_ptr<ov::InferRequest>> infer_request_cache;
static std::unordered_map<struct ggml_cgraph*, std::vector<std::string>> ov_input_names_cache;
@ -256,10 +258,10 @@ ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, cons
} else {
if (param_name == "inp_tokens" || param_name == "inp_pos") {
if (is_first_token) {
size_t max_token_len = ggml_decoder->get_max_token_len();
size_t context_size = ggml_decoder->get_context_size();
const auto* input_tensor_ggml = ggml_decoder->get_input_ggml_tensor(param_name);
std::vector<int32_t> padded_data = pad_input<int32_t>(input_tensor_ggml, 1, max_token_len, 0);
input_tensor = ov::Tensor(ov::element::i32, ov::Shape{ 1, 1, max_token_len });
std::vector<int32_t> padded_data = pad_input<int32_t>(input_tensor_ggml, 1, context_size, 0);
input_tensor = ov::Tensor(ov::element::i32, ov::Shape{1, 1, context_size});
auto* data_ptr = input_tensor.data<int32_t>();
std::copy(padded_data.begin(), padded_data.end(), data_ptr);
} else {
@ -267,18 +269,18 @@ ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, cons
}
} else if (param_name == "KQ_mask") {
size_t max_token_len = ggml_decoder->get_max_token_len();
size_t context_size = ggml_decoder->get_context_size();
const auto* input_tensor_ggml = ggml_decoder->get_input_ggml_tensor(param_name);
if (is_first_token) {
std::vector<float> padded_data =
pad_input<float>(input_tensor_ggml, max_token_len, max_token_len, -INFINITY);
set_zero_diagonal(padded_data, max_token_len);
input_tensor = ov::Tensor(ov::element::f32, ov::Shape{ 1, max_token_len, max_token_len });
pad_input<float>(input_tensor_ggml, context_size, context_size, -INFINITY);
set_zero_diagonal(padded_data, context_size);
input_tensor = ov::Tensor(ov::element::f32, ov::Shape{1, context_size, context_size});
auto* data_ptr = input_tensor.data<float>();
std::copy(padded_data.begin(), padded_data.end(), data_ptr);
} else {
std::vector<float> padded_data = pad_input<float>(input_tensor_ggml, 1, max_token_len, -INFINITY);
input_tensor = ov::Tensor(ov::element::f32, ov::Shape{ 1, 1, max_token_len });
std::vector<float> padded_data = pad_input<float>(input_tensor_ggml, 1, context_size, -INFINITY);
input_tensor = ov::Tensor(ov::element::f32, ov::Shape{1, 1, context_size});
auto* data_ptr = input_tensor.data<float>();
std::copy(padded_data.begin(), padded_data.end(), data_ptr);
}