Change openvino device_type to GPU; Enable flash_attn

Yu, Zijun 2025-09-05 16:41:15 +08:00 committed by Mustafa Cavus
parent 65e1b1af6d
commit 56d596775d
6 changed files with 104 additions and 30 deletions

View File

@@ -299,6 +299,13 @@ void GgmlOvDecoder::add_extra_inputs() {
attention_size = mask->ne[0];
break;
}
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
auto* mask = node->src[3];
if (std::string(mask->name).find("KQ_mask") != 0) {
throw std::runtime_error("Unexpected flash attention node: " + std::string(mask->name));
}
attention_size = mask->ne[0];
}
}
{
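
A minimal standalone sketch of the mask lookup added above (the helper name is illustrative; the src[3] slot and the use of ne[0] follow the hunk, while reading ne[0] as the padded KV length is an assumption):

#include <stdexcept>
#include <string>
#include "ggml.h"

static int64_t flash_attn_attention_size(const ggml_tensor * node) {
    const ggml_tensor * mask = node->src[3];   // GGML_OP_FLASH_ATTN_EXT keeps the KQ mask in src[3]
    if (std::string(mask->name).find("KQ_mask") != 0) {   // name must start with "KQ_mask"
        throw std::runtime_error("Unexpected flash attention node: " + std::string(mask->name));
    }
    return mask->ne[0];   // padded KV length, passed on as the extra attention_size input
}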

View File

@@ -173,14 +173,15 @@ static void ggml_backend_openvino_device_get_memory(ggml_backend_dev_t dev, size
GGML_ASSERT(free != nullptr);
GGML_ASSERT(total != nullptr);
ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *)dev->context;
// Placeholder
GGML_ASSERT(ctx->device >= 0);
// ggml_openvino_set_device(ctx->device);
*total = 1;
*free = 1;
}
static enum ggml_backend_dev_type ggml_backend_openvino_device_get_type(ggml_backend_dev_t dev) {
GGML_UNUSED(dev);
- return GGML_BACKEND_DEVICE_TYPE_ACCEL;
+ return GGML_BACKEND_DEVICE_TYPE_GPU;
}
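
Callers that filter devices by the type reported here will now see the OpenVINO device as a GPU rather than an accelerator. A hedged sketch of such a filter, using the ggml-backend device registry API as I understand the public header (verify the names against ggml-backend.h):

#include <cstdio>
#include "ggml-backend.h"

// Print every registered device that reports itself as a GPU, which now
// includes the OpenVINO device after this change.
static void list_gpu_devices(void) {
    for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
            printf("GPU device %zu: %s\n", i, ggml_backend_dev_name(dev));
        }
    }
}
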
static void ggml_backend_openvino_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
@@ -293,7 +294,7 @@ static bool is_op_unsupported_case(const ggml_tensor* op) {
GGML_LOG_WARN("OpenVINO backend does not support ROPE with mode %d\n", mode);
return true;
}
- if (n_dims != op->src[0]->ne[0]) {
+ if (n_dims != 0.0f && n_dims != op->src[0]->ne[0]) {
GGML_LOG_WARN("OpenVINO backend does not support ROPE with n_dims %d != src[0]->ne[0] %ld\n",
n_dims,
op->src[0]->ne[0]);
@@ -305,7 +306,7 @@ static bool is_op_unsupported_case(const ggml_tensor* op) {
}
float freq_scale;
memcpy(&freq_scale, op_params + 6, sizeof(float));
- if (freq_scale != 1.0f) {
+ if (freq_scale != 0.0f && freq_scale != 1.0f) {
GGML_LOG_WARN("OpenVINO backend does not support ROPE with freq_scale %f != 1.0f\n", freq_scale);
return true;
}
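
For reference, a sketch of how these ROPE fields can be read from a GGML_OP_ROPE node; the freq_scale offset (+6) is the one used above, while the n_dims and mode slots are assumptions about ggml's op_params layout and should be checked against ggml.c:

#include <cstdio>
#include <cstring>
#include "ggml.h"

static void dump_rope_params(const ggml_tensor * op) {
    const int32_t * p = op->op_params;
    const int32_t n_dims = p[1];                  // assumed slot: number of rotated dimensions
    const int32_t mode   = p[2];                  // assumed slot: GGML_ROPE_TYPE_* flags
    float freq_scale;
    memcpy(&freq_scale, p + 6, sizeof(float));    // same offset as the check above
    printf("rope: n_dims=%d mode=%d freq_scale=%f\n", n_dims, mode, (double) freq_scale);
}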

View File

@@ -1,6 +1,12 @@
#include <memory>
#include <openvino/op/broadcast.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/scaled_dot_product_attention.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/op/unsqueeze.hpp>
#include "../node_context.hpp"
#include "../op_table.hpp"
#include "../utils.hpp"
@@ -24,9 +30,53 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
auto q = std::make_shared<ov::op::v0::Convert>(q_f32, ov::element::f16);
auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{}, std::vector<float>{scale});
- auto res = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v , mask, scale_node, false);
- auto res_f32 = std::make_shared<ov::op::v0::Convert>(res, ov::element::f32);
- return rename_outputs_with_suffix({res_f32}, context.get_name());
ov::Output<ov::Node> mask_sliced;
if (context.has_input("KQ_mask_sliced")) {
mask_sliced = context.get_input("KQ_mask_sliced");
} else {
auto token_len = get_dimensions(q, {1});
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
}
if (mask_sliced.get_element_type() != ov::element::f16) {
mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
}
auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<Node> kv) {
int64_t factor = q_batch / kv_batch;
if (factor > 1) {
auto q_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{q_batch});
auto kv_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{kv_batch});
auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
auto kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2});
auto kv_broadcast_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0);
kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape);
auto new_kv_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0);
kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, false);
}
return kv;
};
auto q_shape = context.get_input_shape(0).to_shape();
auto k_shape = context.get_input_shape(1).to_shape();
k = tile_kv(q_shape[0], k_shape[0], k);
v = tile_kv(q_shape[0], k_shape[0], v);
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_sliced, scale_node, false);
auto sdpa_f32 = std::make_shared<ov::op::v0::Convert>(sdpa, ov::element::f32);
auto res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
return rename_outputs_with_suffix({res}, context.get_name());
}
} // namespace op
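
The tile_kv lambda above expands K and V from kv_heads to q_heads with an Unsqueeze -> Broadcast -> Reshape chain. A plain-array sketch of the same head mapping (the [heads, seq, dim] layout and the helper name are illustrative, not the backend's actual memory layout):

#include <algorithm>
#include <cstdint>
#include <vector>

// Repeat each KV head q_heads / kv_heads times so K/V match Q's head count,
// mirroring the Unsqueeze -> Broadcast -> Reshape chain in tile_kv above.
static std::vector<float> tile_kv_heads(const std::vector<float> & kv,
                                        int64_t kv_heads, int64_t q_heads,
                                        int64_t seq, int64_t dim) {
    const int64_t factor = q_heads / kv_heads;   // assumes q_heads % kv_heads == 0
    std::vector<float> out(static_cast<size_t>(q_heads * seq * dim));
    for (int64_t h = 0; h < q_heads; ++h) {
        const int64_t src = h / factor;          // output head h reads KV head h / factor
        std::copy(kv.begin() + src * seq * dim,
                  kv.begin() + (src + 1) * seq * dim,
                  out.begin() + h * seq * dim);
    }
    return out;
}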

View File

@@ -62,7 +62,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
auto B_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{B_batch});
auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});
- auto Z_last_two_dim = get_dimensions(Z.get_node_shared_ptr(), {1, 2});
+ auto Z_last_two_dims = get_dimensions(Z.get_node_shared_ptr(), {1, 2});
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
auto Z_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(Z, unsqueeze_axes);
@@ -70,26 +70,26 @@ OutputVector translate_mulmat(const NodeContext& context) {
Output<Node> batch_small = A_batch_larger ? B_batch_node : A_batch_node;
Output<Node> batch_large = A_batch_larger ? A_batch_node : B_batch_node;
auto broadcast_shape =
- std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_small, factor_node, Z_last_two_dim}, 0);
+ std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_small, factor_node, Z_last_two_dims}, 0);
auto Z_broadcasted = std::make_shared<ov::op::v3::Broadcast>(Z_unsqueezed, broadcast_shape);
- auto new_Z_shape = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_large, Z_last_two_dim}, 0);
+ auto new_Z_shape = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_large, Z_last_two_dims}, 0);
Z = std::make_shared<ov::op::v1::Reshape>(Z_broadcasted, new_Z_shape, false);
}
- if (A_batch_larger) {
- B = Z;
- } else {
- A = Z;
- }
}
+ if (A_batch_larger) {
+ B = Z;
+ } else {
+ A = Z;
+ }
- if (convert_out_type) {
- auto result_lp = std::make_shared<ov::op::v0::MatMul>(A, B, false, transpose_b);
- res = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
- } else {
- res = std::make_shared<ov::op::v0::MatMul>(A, B, false, transpose_b);
- }
+ if (convert_out_type) {
+ auto result_lp = std::make_shared<ov::op::v0::MatMul>(A, B, false, transpose_b);
+ res = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
+ } else {
+ res = std::make_shared<ov::op::v0::MatMul>(A, B, false, transpose_b);
+ }
- return rename_outputs_with_suffix({res}, context.get_name());
+ return rename_outputs_with_suffix({res}, context.get_name());
}
} // namespace op
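
The same broadcast trick is applied here so a MatMul operand with fewer batches (for example shared KV heads) is expanded to match the other side. A small worked example of the factor computation, with hypothetical shapes:

#include <cassert>
#include <cstdint>

// Hypothetical shapes: A is [32, T, 128] (one slice per Q head) and B is
// [8, 128, S] (one slice per KV head). factor = 32 / 8 = 4, so B is
// unsqueezed to [8, 1, 128, S], broadcast to [8, 4, 128, S] and reshaped
// to [32, 128, S] before the MatMul, exactly as Z is handled above.
static int64_t broadcast_factor(int64_t a_batch, int64_t b_batch) {
    const int64_t larger  = a_batch > b_batch ? a_batch : b_batch;
    const int64_t smaller = a_batch > b_batch ? b_batch : a_batch;
    assert(smaller > 0 && larger % smaller == 0);   // batch sizes must divide evenly
    return larger / smaller;                        // 1 means no broadcast is needed
}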

View File

@@ -51,14 +51,18 @@ OutputVector translate_soft_max(const NodeContext& context) {
return rename_outputs_with_suffix({res}, context.get_name());
}
- auto mask_node = context.get_input(1);
ov::Output<ov::Node> mask_node_sliced;
if (context.has_input("KQ_mask_sliced")) {
mask_node_sliced = context.get_input("KQ_mask_sliced");
} else {
auto token_len = get_dimensions(input_node, {1});
auto mask_node = context.get_input(1);
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
mask_node_sliced = std::make_shared<ov::op::v8::Slice>(mask_node, zero, token_len, one, one);
}
- auto token_len = context.has_input("token_len") ? context.get_input("token_len") : get_dimensions(input_node, {1});
- auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
- auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
- std::shared_ptr<ov::Node> mask_node_sliced =
- std::make_shared<ov::op::v8::Slice>(mask_node, zero, token_len, one, one);
- if (mask_node_sliced->get_element_type() != context.get_output_type(0)) {
+ if (mask_node_sliced.get_element_type() != context.get_output_type(0)) {
mask_node_sliced = std::make_shared<ov::op::v0::Convert>(mask_node_sliced, context.get_output_type(0));
}

View File

@@ -36,6 +36,7 @@ namespace ggml {
using namespace ov::op;
namespace {
ov::pass::MakeStateful::ParamResPairs get_kv_param_res_pairs(
const std::shared_ptr<ov::Model>& model, const std::map<std::string, std::string>& kv_param_res_names) {
ov::pass::MakeStateful::ParamResPairs pairs;
@@ -76,6 +77,16 @@ void add_token_len(TensorMap& tensor_map) {
tensor_map.insert({"token_len", token_len->output(0)});
}
void add_sliced_mask(TensorMap& tensor_map) {
auto mask = tensor_map.at("KQ_mask").get_node_shared_ptr();
auto token_len = tensor_map.at("token_len").get_node_shared_ptr();
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
std::shared_ptr<ov::Node> mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
mask_sliced->set_friendly_name("KQ_mask_sliced");
tensor_map.insert({"KQ_mask_sliced", mask_sliced->output(0)});
}
void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
int32_t* rope_params = ggml_model_decoder.get_rope_params();
auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr();
@@ -97,6 +108,7 @@ void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
// Create common patterns
void preprocess(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
add_token_len(tensor_map);
add_sliced_mask(tensor_map);
add_rope_sin_cos(tensor_map, ggml_model_decoder);
}
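
add_sliced_mask publishes a single shared "KQ_mask_sliced" node that the soft_max and flash_attn translators pick up through has_input. A plain-array sketch of what the Slice keeps (the row-major [padded_tokens, kv_len] layout is assumed for illustration):

#include <cstddef>
#include <vector>

// Keep only the first token_len rows of the padded KQ mask, mirroring
// Slice(mask, start = 0, stop = token_len, step = 1) on the token axis.
static std::vector<float> slice_kq_mask(const std::vector<float> & mask,
                                        size_t token_len, size_t kv_len) {
    return std::vector<float>(mask.begin(), mask.begin() + token_len * kv_len);
}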