Change openvino device_type to GPU; Enable flash_attn
This commit is contained in:
parent
65e1b1af6d
commit
56d596775d
|
|
@ -299,6 +299,13 @@ void GgmlOvDecoder::add_extra_inputs() {
|
||||||
attention_size = mask->ne[0];
|
attention_size = mask->ne[0];
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
|
||||||
|
auto* mask = node->src[3];
|
||||||
|
if (std::string(mask->name).find("KQ_mask") != 0) {
|
||||||
|
throw std::runtime_error("Unexpected flash attention node: " + std::string(mask->name));
|
||||||
|
}
|
||||||
|
attention_size = mask->ne[0];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -173,14 +173,15 @@ static void ggml_backend_openvino_device_get_memory(ggml_backend_dev_t dev, size
|
||||||
GGML_ASSERT(free != nullptr);
|
GGML_ASSERT(free != nullptr);
|
||||||
GGML_ASSERT(total != nullptr);
|
GGML_ASSERT(total != nullptr);
|
||||||
ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *)dev->context;
|
ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *)dev->context;
|
||||||
// Placeholder
|
|
||||||
GGML_ASSERT(ctx->device >= 0);
|
GGML_ASSERT(ctx->device >= 0);
|
||||||
// ggml_openvino_set_device(ctx->device);
|
// ggml_openvino_set_device(ctx->device);
|
||||||
|
*total = 1;
|
||||||
|
*free = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
static enum ggml_backend_dev_type ggml_backend_openvino_device_get_type(ggml_backend_dev_t dev) {
|
static enum ggml_backend_dev_type ggml_backend_openvino_device_get_type(ggml_backend_dev_t dev) {
|
||||||
GGML_UNUSED(dev);
|
GGML_UNUSED(dev);
|
||||||
return GGML_BACKEND_DEVICE_TYPE_ACCEL;
|
return GGML_BACKEND_DEVICE_TYPE_GPU;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_backend_openvino_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
static void ggml_backend_openvino_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
||||||
|
|
@ -293,7 +294,7 @@ static bool is_op_unsupported_case(const ggml_tensor* op) {
|
||||||
GGML_LOG_WARN("OpenVINO backend does not support ROPE with mode %d\n", mode);
|
GGML_LOG_WARN("OpenVINO backend does not support ROPE with mode %d\n", mode);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (n_dims != op->src[0]->ne[0]) {
|
if (n_dims != 0.0f && n_dims != op->src[0]->ne[0]) {
|
||||||
GGML_LOG_WARN("OpenVINO backend does not support ROPE with n_dims %d != src[0]->ne[0] %ld\n",
|
GGML_LOG_WARN("OpenVINO backend does not support ROPE with n_dims %d != src[0]->ne[0] %ld\n",
|
||||||
n_dims,
|
n_dims,
|
||||||
op->src[0]->ne[0]);
|
op->src[0]->ne[0]);
|
||||||
|
|
@ -305,7 +306,7 @@ static bool is_op_unsupported_case(const ggml_tensor* op) {
|
||||||
}
|
}
|
||||||
float freq_scale;
|
float freq_scale;
|
||||||
memcpy(&freq_scale, op_params + 6, sizeof(float));
|
memcpy(&freq_scale, op_params + 6, sizeof(float));
|
||||||
if (freq_scale != 1.0f) {
|
if (freq_scale != 0.0f && freq_scale != 1.0f) {
|
||||||
GGML_LOG_WARN("OpenVINO backend does not support ROPE with freq_scale %f != 1.0f\n", freq_scale);
|
GGML_LOG_WARN("OpenVINO backend does not support ROPE with freq_scale %f != 1.0f\n", freq_scale);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,12 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <openvino/op/broadcast.hpp>
|
||||||
|
#include <openvino/op/concat.hpp>
|
||||||
#include <openvino/op/convert.hpp>
|
#include <openvino/op/convert.hpp>
|
||||||
|
#include <openvino/op/reshape.hpp>
|
||||||
#include <openvino/op/scaled_dot_product_attention.hpp>
|
#include <openvino/op/scaled_dot_product_attention.hpp>
|
||||||
|
#include <openvino/op/transpose.hpp>
|
||||||
|
#include <openvino/op/unsqueeze.hpp>
|
||||||
|
|
||||||
#include "../node_context.hpp"
|
#include "../node_context.hpp"
|
||||||
#include "../op_table.hpp"
|
#include "../op_table.hpp"
|
||||||
#include "../utils.hpp"
|
#include "../utils.hpp"
|
||||||
|
|
@ -24,9 +30,53 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
|
||||||
|
|
||||||
auto q = std::make_shared<ov::op::v0::Convert>(q_f32, ov::element::f16);
|
auto q = std::make_shared<ov::op::v0::Convert>(q_f32, ov::element::f16);
|
||||||
auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{}, std::vector<float>{scale});
|
auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{}, std::vector<float>{scale});
|
||||||
auto res = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v , mask, scale_node, false);
|
|
||||||
auto res_f32 = std::make_shared<ov::op::v0::Convert>(res, ov::element::f32);
|
ov::Output<ov::Node> mask_sliced;
|
||||||
return rename_outputs_with_suffix({res_f32}, context.get_name());
|
if (context.has_input("KQ_mask_sliced")) {
|
||||||
|
mask_sliced = context.get_input("KQ_mask_sliced");
|
||||||
|
} else {
|
||||||
|
auto token_len = get_dimensions(q, {1});
|
||||||
|
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||||
|
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||||
|
mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mask_sliced.get_element_type() != ov::element::f16) {
|
||||||
|
mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<Node> kv) {
|
||||||
|
int64_t factor = q_batch / kv_batch;
|
||||||
|
if (factor > 1) {
|
||||||
|
auto q_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{q_batch});
|
||||||
|
auto kv_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{kv_batch});
|
||||||
|
auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});
|
||||||
|
|
||||||
|
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
|
||||||
|
auto kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
|
||||||
|
|
||||||
|
auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2});
|
||||||
|
auto kv_broadcast_shape =
|
||||||
|
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0);
|
||||||
|
kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape);
|
||||||
|
|
||||||
|
auto new_kv_shape =
|
||||||
|
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0);
|
||||||
|
kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, false);
|
||||||
|
}
|
||||||
|
return kv;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto q_shape = context.get_input_shape(0).to_shape();
|
||||||
|
auto k_shape = context.get_input_shape(1).to_shape();
|
||||||
|
k = tile_kv(q_shape[0], k_shape[0], k);
|
||||||
|
v = tile_kv(q_shape[0], k_shape[0], v);
|
||||||
|
|
||||||
|
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_sliced, scale_node, false);
|
||||||
|
auto sdpa_f32 = std::make_shared<ov::op::v0::Convert>(sdpa, ov::element::f32);
|
||||||
|
auto res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
|
||||||
|
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
|
||||||
|
return rename_outputs_with_suffix({res}, context.get_name());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
|
||||||
auto B_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{B_batch});
|
auto B_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{B_batch});
|
||||||
auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});
|
auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});
|
||||||
|
|
||||||
auto Z_last_two_dim = get_dimensions(Z.get_node_shared_ptr(), {1, 2});
|
auto Z_last_two_dims = get_dimensions(Z.get_node_shared_ptr(), {1, 2});
|
||||||
|
|
||||||
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
|
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
|
||||||
auto Z_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(Z, unsqueeze_axes);
|
auto Z_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(Z, unsqueeze_axes);
|
||||||
|
|
@ -70,26 +70,26 @@ OutputVector translate_mulmat(const NodeContext& context) {
|
||||||
Output<Node> batch_small = A_batch_larger ? B_batch_node : A_batch_node;
|
Output<Node> batch_small = A_batch_larger ? B_batch_node : A_batch_node;
|
||||||
Output<Node> batch_large = A_batch_larger ? A_batch_node : B_batch_node;
|
Output<Node> batch_large = A_batch_larger ? A_batch_node : B_batch_node;
|
||||||
auto broadcast_shape =
|
auto broadcast_shape =
|
||||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_small, factor_node, Z_last_two_dim}, 0);
|
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_small, factor_node, Z_last_two_dims}, 0);
|
||||||
auto Z_broadcasted = std::make_shared<ov::op::v3::Broadcast>(Z_unsqueezed, broadcast_shape);
|
auto Z_broadcasted = std::make_shared<ov::op::v3::Broadcast>(Z_unsqueezed, broadcast_shape);
|
||||||
|
|
||||||
auto new_Z_shape = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_large, Z_last_two_dim}, 0);
|
auto new_Z_shape = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_large, Z_last_two_dims}, 0);
|
||||||
Z = std::make_shared<ov::op::v1::Reshape>(Z_broadcasted, new_Z_shape, false);
|
Z = std::make_shared<ov::op::v1::Reshape>(Z_broadcasted, new_Z_shape, false);
|
||||||
}
|
}
|
||||||
if (A_batch_larger) {
|
if (A_batch_larger) {
|
||||||
B = Z;
|
B = Z;
|
||||||
} else {
|
} else {
|
||||||
A = Z;
|
A = Z;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (convert_out_type) {
|
if (convert_out_type) {
|
||||||
auto result_lp = std::make_shared<ov::op::v0::MatMul>(A, B, false, transpose_b);
|
auto result_lp = std::make_shared<ov::op::v0::MatMul>(A, B, false, transpose_b);
|
||||||
res = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
|
res = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
|
||||||
} else {
|
} else {
|
||||||
res = std::make_shared<ov::op::v0::MatMul>(A, B, false, transpose_b);
|
res = std::make_shared<ov::op::v0::MatMul>(A, B, false, transpose_b);
|
||||||
}
|
}
|
||||||
|
|
||||||
return rename_outputs_with_suffix({res}, context.get_name());
|
return rename_outputs_with_suffix({res}, context.get_name());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
|
|
||||||
|
|
@ -51,14 +51,18 @@ OutputVector translate_soft_max(const NodeContext& context) {
|
||||||
return rename_outputs_with_suffix({res}, context.get_name());
|
return rename_outputs_with_suffix({res}, context.get_name());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto mask_node = context.get_input(1);
|
ov::Output<ov::Node> mask_node_sliced;
|
||||||
|
if (context.has_input("KQ_mask_sliced")) {
|
||||||
|
mask_node_sliced = context.get_input("KQ_mask_sliced");
|
||||||
|
} else {
|
||||||
|
auto token_len = get_dimensions(input_node, {1});
|
||||||
|
auto mask_node = context.get_input(1);
|
||||||
|
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||||
|
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||||
|
mask_node_sliced = std::make_shared<ov::op::v8::Slice>(mask_node, zero, token_len, one, one);
|
||||||
|
}
|
||||||
|
|
||||||
auto token_len = context.has_input("token_len") ? context.get_input("token_len") : get_dimensions(input_node, {1});
|
if (mask_node_sliced.get_element_type() != context.get_output_type(0)) {
|
||||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
|
||||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
|
||||||
std::shared_ptr<ov::Node> mask_node_sliced =
|
|
||||||
std::make_shared<ov::op::v8::Slice>(mask_node, zero, token_len, one, one);
|
|
||||||
if (mask_node_sliced->get_element_type() != context.get_output_type(0)) {
|
|
||||||
mask_node_sliced = std::make_shared<ov::op::v0::Convert>(mask_node_sliced, context.get_output_type(0));
|
mask_node_sliced = std::make_shared<ov::op::v0::Convert>(mask_node_sliced, context.get_output_type(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ namespace ggml {
|
||||||
using namespace ov::op;
|
using namespace ov::op;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
ov::pass::MakeStateful::ParamResPairs get_kv_param_res_pairs(
|
ov::pass::MakeStateful::ParamResPairs get_kv_param_res_pairs(
|
||||||
const std::shared_ptr<ov::Model>& model, const std::map<std::string, std::string>& kv_param_res_names) {
|
const std::shared_ptr<ov::Model>& model, const std::map<std::string, std::string>& kv_param_res_names) {
|
||||||
ov::pass::MakeStateful::ParamResPairs pairs;
|
ov::pass::MakeStateful::ParamResPairs pairs;
|
||||||
|
|
@ -76,6 +77,16 @@ void add_token_len(TensorMap& tensor_map) {
|
||||||
tensor_map.insert({"token_len", token_len->output(0)});
|
tensor_map.insert({"token_len", token_len->output(0)});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void add_sliced_mask(TensorMap& tensor_map) {
|
||||||
|
auto mask = tensor_map.at("KQ_mask").get_node_shared_ptr();
|
||||||
|
auto token_len = tensor_map.at("token_len").get_node_shared_ptr();
|
||||||
|
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||||
|
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||||
|
std::shared_ptr<ov::Node> mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
|
||||||
|
mask_sliced->set_friendly_name("KQ_mask_sliced");
|
||||||
|
tensor_map.insert({"KQ_mask_sliced", mask_sliced->output(0)});
|
||||||
|
}
|
||||||
|
|
||||||
void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
|
void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
|
||||||
int32_t* rope_params = ggml_model_decoder.get_rope_params();
|
int32_t* rope_params = ggml_model_decoder.get_rope_params();
|
||||||
auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr();
|
auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr();
|
||||||
|
|
@ -97,6 +108,7 @@ void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
|
||||||
// Create common patterns
|
// Create common patterns
|
||||||
void preprocess(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
|
void preprocess(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
|
||||||
add_token_len(tensor_map);
|
add_token_len(tensor_map);
|
||||||
|
add_sliced_mask(tensor_map);
|
||||||
add_rope_sin_cos(tensor_map, ggml_model_decoder);
|
add_rope_sin_cos(tensor_map, ggml_model_decoder);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue