Change openvino device_type to GPU; Enable flash_attn
This commit is contained in:
parent
65e1b1af6d
commit
56d596775d
|
|
@ -299,6 +299,13 @@ void GgmlOvDecoder::add_extra_inputs() {
|
|||
attention_size = mask->ne[0];
|
||||
break;
|
||||
}
|
||||
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
|
||||
auto* mask = node->src[3];
|
||||
if (std::string(mask->name).find("KQ_mask") != 0) {
|
||||
throw std::runtime_error("Unexpected flash attention node: " + std::string(mask->name));
|
||||
}
|
||||
attention_size = mask->ne[0];
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
|
|
|
|||
|
|
@ -173,14 +173,15 @@ static void ggml_backend_openvino_device_get_memory(ggml_backend_dev_t dev, size
|
|||
GGML_ASSERT(free != nullptr);
|
||||
GGML_ASSERT(total != nullptr);
|
||||
ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *)dev->context;
|
||||
// Placeholder
|
||||
GGML_ASSERT(ctx->device >= 0);
|
||||
// ggml_openvino_set_device(ctx->device);
|
||||
*total = 1;
|
||||
*free = 1;
|
||||
}
|
||||
|
||||
static enum ggml_backend_dev_type ggml_backend_openvino_device_get_type(ggml_backend_dev_t dev) {
|
||||
GGML_UNUSED(dev);
|
||||
return GGML_BACKEND_DEVICE_TYPE_ACCEL;
|
||||
return GGML_BACKEND_DEVICE_TYPE_GPU;
|
||||
}
|
||||
|
||||
static void ggml_backend_openvino_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
|
||||
|
|
@ -293,7 +294,7 @@ static bool is_op_unsupported_case(const ggml_tensor* op) {
|
|||
GGML_LOG_WARN("OpenVINO backend does not support ROPE with mode %d\n", mode);
|
||||
return true;
|
||||
}
|
||||
if (n_dims != op->src[0]->ne[0]) {
|
||||
if (n_dims != 0.0f && n_dims != op->src[0]->ne[0]) {
|
||||
GGML_LOG_WARN("OpenVINO backend does not support ROPE with n_dims %d != src[0]->ne[0] %ld\n",
|
||||
n_dims,
|
||||
op->src[0]->ne[0]);
|
||||
|
|
@ -305,7 +306,7 @@ static bool is_op_unsupported_case(const ggml_tensor* op) {
|
|||
}
|
||||
float freq_scale;
|
||||
memcpy(&freq_scale, op_params + 6, sizeof(float));
|
||||
if (freq_scale != 1.0f) {
|
||||
if (freq_scale != 0.0f && freq_scale != 1.0f) {
|
||||
GGML_LOG_WARN("OpenVINO backend does not support ROPE with freq_scale %f != 1.0f\n", freq_scale);
|
||||
return true;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,12 @@
|
|||
#include <memory>
|
||||
#include <openvino/op/broadcast.hpp>
|
||||
#include <openvino/op/concat.hpp>
|
||||
#include <openvino/op/convert.hpp>
|
||||
#include <openvino/op/reshape.hpp>
|
||||
#include <openvino/op/scaled_dot_product_attention.hpp>
|
||||
#include <openvino/op/transpose.hpp>
|
||||
#include <openvino/op/unsqueeze.hpp>
|
||||
|
||||
#include "../node_context.hpp"
|
||||
#include "../op_table.hpp"
|
||||
#include "../utils.hpp"
|
||||
|
|
@ -24,9 +30,53 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
|
|||
|
||||
auto q = std::make_shared<ov::op::v0::Convert>(q_f32, ov::element::f16);
|
||||
auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{}, std::vector<float>{scale});
|
||||
auto res = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v , mask, scale_node, false);
|
||||
auto res_f32 = std::make_shared<ov::op::v0::Convert>(res, ov::element::f32);
|
||||
return rename_outputs_with_suffix({res_f32}, context.get_name());
|
||||
|
||||
ov::Output<ov::Node> mask_sliced;
|
||||
if (context.has_input("KQ_mask_sliced")) {
|
||||
mask_sliced = context.get_input("KQ_mask_sliced");
|
||||
} else {
|
||||
auto token_len = get_dimensions(q, {1});
|
||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
|
||||
}
|
||||
|
||||
if (mask_sliced.get_element_type() != ov::element::f16) {
|
||||
mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
|
||||
}
|
||||
|
||||
auto tile_kv = [](int64_t q_batch, int64_t kv_batch, ov::Output<Node> kv) {
|
||||
int64_t factor = q_batch / kv_batch;
|
||||
if (factor > 1) {
|
||||
auto q_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{q_batch});
|
||||
auto kv_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{kv_batch});
|
||||
auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});
|
||||
|
||||
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
|
||||
auto kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);
|
||||
|
||||
auto kv_last_two_dims = get_dimensions(kv.get_node_shared_ptr(), {1, 2});
|
||||
auto kv_broadcast_shape =
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{kv_batch_node, factor_node, kv_last_two_dims}, 0);
|
||||
kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape);
|
||||
|
||||
auto new_kv_shape =
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{q_batch_node, kv_last_two_dims}, 0);
|
||||
kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, false);
|
||||
}
|
||||
return kv;
|
||||
};
|
||||
|
||||
auto q_shape = context.get_input_shape(0).to_shape();
|
||||
auto k_shape = context.get_input_shape(1).to_shape();
|
||||
k = tile_kv(q_shape[0], k_shape[0], k);
|
||||
v = tile_kv(q_shape[0], k_shape[0], v);
|
||||
|
||||
auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_sliced, scale_node, false);
|
||||
auto sdpa_f32 = std::make_shared<ov::op::v0::Convert>(sdpa, ov::element::f32);
|
||||
auto res = std::make_shared<ov::op::v1::Transpose>(sdpa_f32,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
|
|||
auto B_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{B_batch});
|
||||
auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});
|
||||
|
||||
auto Z_last_two_dim = get_dimensions(Z.get_node_shared_ptr(), {1, 2});
|
||||
auto Z_last_two_dims = get_dimensions(Z.get_node_shared_ptr(), {1, 2});
|
||||
|
||||
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
|
||||
auto Z_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(Z, unsqueeze_axes);
|
||||
|
|
@ -70,26 +70,26 @@ OutputVector translate_mulmat(const NodeContext& context) {
|
|||
Output<Node> batch_small = A_batch_larger ? B_batch_node : A_batch_node;
|
||||
Output<Node> batch_large = A_batch_larger ? A_batch_node : B_batch_node;
|
||||
auto broadcast_shape =
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_small, factor_node, Z_last_two_dim}, 0);
|
||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_small, factor_node, Z_last_two_dims}, 0);
|
||||
auto Z_broadcasted = std::make_shared<ov::op::v3::Broadcast>(Z_unsqueezed, broadcast_shape);
|
||||
|
||||
auto new_Z_shape = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_large, Z_last_two_dim}, 0);
|
||||
auto new_Z_shape = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{batch_large, Z_last_two_dims}, 0);
|
||||
Z = std::make_shared<ov::op::v1::Reshape>(Z_broadcasted, new_Z_shape, false);
|
||||
}
|
||||
if (A_batch_larger) {
|
||||
B = Z;
|
||||
} else {
|
||||
A = Z;
|
||||
}
|
||||
}
|
||||
if (A_batch_larger) {
|
||||
B = Z;
|
||||
} else {
|
||||
A = Z;
|
||||
}
|
||||
|
||||
if (convert_out_type) {
|
||||
auto result_lp = std::make_shared<ov::op::v0::MatMul>(A, B, false, transpose_b);
|
||||
res = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
|
||||
} else {
|
||||
res = std::make_shared<ov::op::v0::MatMul>(A, B, false, transpose_b);
|
||||
}
|
||||
if (convert_out_type) {
|
||||
auto result_lp = std::make_shared<ov::op::v0::MatMul>(A, B, false, transpose_b);
|
||||
res = std::make_shared<ov::op::v0::Convert>(result_lp, context.get_output_type(0));
|
||||
} else {
|
||||
res = std::make_shared<ov::op::v0::MatMul>(A, B, false, transpose_b);
|
||||
}
|
||||
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
|
|
|
|||
|
|
@ -51,14 +51,18 @@ OutputVector translate_soft_max(const NodeContext& context) {
|
|||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
}
|
||||
|
||||
auto mask_node = context.get_input(1);
|
||||
ov::Output<ov::Node> mask_node_sliced;
|
||||
if (context.has_input("KQ_mask_sliced")) {
|
||||
mask_node_sliced = context.get_input("KQ_mask_sliced");
|
||||
} else {
|
||||
auto token_len = get_dimensions(input_node, {1});
|
||||
auto mask_node = context.get_input(1);
|
||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
mask_node_sliced = std::make_shared<ov::op::v8::Slice>(mask_node, zero, token_len, one, one);
|
||||
}
|
||||
|
||||
auto token_len = context.has_input("token_len") ? context.get_input("token_len") : get_dimensions(input_node, {1});
|
||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
std::shared_ptr<ov::Node> mask_node_sliced =
|
||||
std::make_shared<ov::op::v8::Slice>(mask_node, zero, token_len, one, one);
|
||||
if (mask_node_sliced->get_element_type() != context.get_output_type(0)) {
|
||||
if (mask_node_sliced.get_element_type() != context.get_output_type(0)) {
|
||||
mask_node_sliced = std::make_shared<ov::op::v0::Convert>(mask_node_sliced, context.get_output_type(0));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ namespace ggml {
|
|||
using namespace ov::op;
|
||||
|
||||
namespace {
|
||||
|
||||
ov::pass::MakeStateful::ParamResPairs get_kv_param_res_pairs(
|
||||
const std::shared_ptr<ov::Model>& model, const std::map<std::string, std::string>& kv_param_res_names) {
|
||||
ov::pass::MakeStateful::ParamResPairs pairs;
|
||||
|
|
@ -76,6 +77,16 @@ void add_token_len(TensorMap& tensor_map) {
|
|||
tensor_map.insert({"token_len", token_len->output(0)});
|
||||
}
|
||||
|
||||
void add_sliced_mask(TensorMap& tensor_map) {
|
||||
auto mask = tensor_map.at("KQ_mask").get_node_shared_ptr();
|
||||
auto token_len = tensor_map.at("token_len").get_node_shared_ptr();
|
||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
||||
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
|
||||
std::shared_ptr<ov::Node> mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
|
||||
mask_sliced->set_friendly_name("KQ_mask_sliced");
|
||||
tensor_map.insert({"KQ_mask_sliced", mask_sliced->output(0)});
|
||||
}
|
||||
|
||||
void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
|
||||
int32_t* rope_params = ggml_model_decoder.get_rope_params();
|
||||
auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr();
|
||||
|
|
@ -97,6 +108,7 @@ void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
|
|||
// Create common patterns
|
||||
void preprocess(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
|
||||
add_token_len(tensor_map);
|
||||
add_sliced_mask(tensor_map);
|
||||
add_rope_sin_cos(tensor_map, ggml_model_decoder);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue