Support iSWA

parent 7d81861a18
commit 9de874cb7b

@@ -30,17 +30,21 @@
 #include <set>
+#include <stdexcept>
 #include <string>
 #include <vector>

 #include "ggml-backend-impl.h"
 #include "ggml-backend.h"
 #include "ggml-quants.hpp"

 GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token,
-                             int context_size, int num_heads, int num_heads_kv, int head_size) :
+                             int context_size, int context_size_swa, int num_heads, int num_heads_kv, int head_size,
+                             const std::vector<int>& swa_layers) :
     m_cgraph(cgraph),
     m_node(node),
     m_op_name(std::string(node->name)),
     m_context_size(context_size),
+    m_context_size_swa(context_size_swa),
+    m_swa_layers(swa_layers),
     m_num_heads(num_heads),
     m_num_heads_kv(num_heads_kv),
     m_head_size(head_size),
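
Note: iSWA (interleaved sliding-window attention, used by models such as Gemma) mixes sliding-window layers with full-attention layers, so the decoder now carries two KV-cache lengths plus the list of SWA layers. A minimal standalone sketch of that bookkeeping (all values below are hypothetical, not taken from this commit):

    #include <algorithm>
    #include <vector>

    int main() {
        // Hypothetical layout: every 4th layer keeps full attention,
        // the rest use a sliding window.
        std::vector<int> swa_layers = {0, 1, 2, 4, 5, 6, 8, 9, 10};
        const int context_size = 4096;     // KV length, full-attention layers
        const int context_size_swa = 512;  // KV length, sliding-window layers

        // Mirrors GgmlOvDecoder::is_swa_layer, defined later in this commit.
        auto is_swa_layer = [&](int layer) {
            return std::find(swa_layers.begin(), swa_layers.end(), layer) != swa_layers.end();
        };

        // Each layer's cache_k/cache_v tensors are sized with the matching length.
        return (is_swa_layer(3) ? context_size_swa : context_size) == 4096 ? 0 : 1;
    }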

@@ -204,11 +208,14 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
         if (node->src[0]->op != GGML_OP_VIEW) {
             m_op_case = 1;
         } else if (ggml_is_contiguous(node->src[0])) {
-            // Permute cache_k (view)
-            m_op_case = 2;
-        } else {
-            // Permute cache_v (view), deprecated, cache_v will also fall to case 2
-            m_op_case = 3;
+            // Permute kv cache (view)
+            std::string src_name(node->view_src->name);
+            int layer = extract_layer_from_name(src_name);
+            if (!is_swa_layer(layer)) {
+                m_op_case = 2;
+            } else {
+                m_op_case = 3;
+            }
         }
         break;
     }

@@ -239,13 +246,34 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
     }
 }

+int extract_layer_from_name(const std::string& name) {
+    size_t pos1 = name.find("_l");
+    assert(pos1 != std::string::npos);
+    pos1 += 2;
+    size_t pos2 = name.find(' ', pos1);
+    if (pos2 == std::string::npos) {
+        pos2 = name.length();
+    }
+    std::string layer_str = name.substr(pos1, pos2 - pos1);
+    int layer = std::stoi(layer_str);
+    return layer;
+}
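
A quick standalone check of the naming convention this helper assumes: layer indices follow "_l" in tensor names such as "cache_k_l0 (view)" and run to the next space or to the end of the string. The reimplementation below is illustrative only:

    #include <cassert>
    #include <string>

    // Condensed copy of extract_layer_from_name, for illustration.
    static int extract_layer(const std::string& name) {
        size_t pos1 = name.find("_l") + 2;
        size_t pos2 = name.find(' ', pos1);
        if (pos2 == std::string::npos) pos2 = name.length();
        return std::stoi(name.substr(pos1, pos2 - pos1));
    }

    int main() {
        assert(extract_layer("cache_k_l0 (view)") == 0);  // index before " (view)"
        assert(extract_layer("cache_v_l23") == 23);       // digits run to the end
        return 0;
    }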
+
 void GgmlOvDecoder::set_llm_params() {
     for (int i = 0; i < m_cgraph->n_nodes; i++) {
         auto* node = m_cgraph->nodes[i];
         std::string name = std::string(node->name);
-        if (node->op == GGML_OP_VIEW && std::string(node->name) == "cache_k_l0 (view)") {
-            auto* cache_k = node->src[0];
-            m_context_size = cache_k->ne[1];
+        if (node->op == GGML_OP_FLASH_ATTN_EXT) {
+            auto* cache_k = node->src[1];
+            cache_k = cache_k->view_src ? cache_k->view_src : cache_k;
+            int layer = extract_layer_from_name(cache_k->name);
+
+            if (std::string(node->src[3]->name).find("swa") != std::string::npos) {
+                m_swa_layers.push_back(layer);
+                m_context_size_swa = cache_k->ne[1];
+            } else {
+                m_context_size = cache_k->ne[1];
+            }
         } else if (node->op == GGML_OP_ROPE &&
                    (name.find("Qcur-0") == 0 || std::string(node->src[0]->name).find("Qcur-0") == 0)) {
             m_head_size = node->ne[0];
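
For orientation: set_llm_params now reads both cache lengths off the graph's FLASH_ATTN_EXT nodes (resolving views back to the underlying cache tensor via view_src), and the SWA/full-attention split is keyed purely on the mask tensor's name, src[3] of the node. A tiny sketch of that test, with illustrative names following the "KQ_mask"/"KQ_mask_swa" convention used in this commit:

    #include <cassert>
    #include <string>

    int main() {
        std::string full_mask = "KQ_mask";      // updates m_context_size
        std::string swa_mask  = "KQ_mask_swa";  // updates m_context_size_swa, m_swa_layers
        assert(full_mask.find("swa") == std::string::npos);
        assert(swa_mask.find("swa") != std::string::npos);
        return 0;
    }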

@@ -269,11 +297,11 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
             input_shape = ov::PartialShape{1, 1, 1};
         }
     } else {
-        input_shape = ov::PartialShape{1, 1, ov::Dimension(1, m_context_size)};
+        input_shape = ov::PartialShape{1, 1, -1};
     }
 } else if (name == "inp_out_ids" && !m_is_static) {
-    input_shape = ov::PartialShape{1, 1, ov::Dimension(1, m_context_size)};
-} else if (name == "KQ_mask") {
+    input_shape = ov::PartialShape{1, 1, -1};
+} else if (name.find("KQ_mask") == 0) {
     if (m_is_static) {
         if (m_is_first_token) {
             input_shape = ov::PartialShape{1, m_context_size, m_context_size};

@@ -281,13 +309,12 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
             input_shape = ov::PartialShape{1, 1, m_context_size};
         }
     } else {
-        auto max_mask_size = GGML_PAD(m_context_size, GGML_KQ_MASK_PAD);
-        input_shape = ov::PartialShape{1, ov::Dimension(1, max_mask_size), ov::Dimension(1, max_mask_size)};
+        input_shape = ov::PartialShape{1, -1, -1};
     }
-} else if (name.find("cache_k") == 0) {
-    input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size};
-} else if (name.find("cache_v") == 0) {
-    input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size};
+} else if (name.find("cache_") == 0) {
+    int layer = extract_layer_from_name(name);
+    bool is_swa = is_swa_layer(layer);
+    input_shape = ov::PartialShape{is_swa ? m_context_size_swa : m_context_size, m_num_heads_kv, m_head_size};
 } else if (const auto* op = get_tensor_used_op(src); op && op->op == GGML_OP_SET_ROWS) {
     input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
 } else if (src->op == GGML_OP_VIEW) {

@@ -305,35 +332,35 @@ void GgmlOvDecoder::add_extra_inputs() {
     // see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding.
     // Not used for NPU
     int64_t attention_size = -1;
+    int64_t attention_size_swa = -1;
     for (const auto& node : m_nodes) {
         if (node->op == GGML_OP_SOFT_MAX) {
             auto* mask = node->src[1];
             if (std::string(mask->name).find("KQ_mask") != 0) {
                 throw std::runtime_error("Unexpected softmax node: " + std::string(mask->name));
             }
             attention_size = mask->ne[0];
             break;
         }
         if (node->op == GGML_OP_FLASH_ATTN_EXT) {
             auto* mask = node->src[3];
-            if (std::string(mask->name).find("KQ_mask") != 0) {
+            std::string mask_name(mask->name);
+            if (mask_name.find("KQ_mask") != 0) {
                 throw std::runtime_error("Unexpected flash attention node: " + std::string(mask->name));
             }
-            attention_size = mask->ne[0];
+            if (mask_name.find("swa") != std::string::npos) {
+                attention_size_swa = mask->ne[0];
+            } else {
+                attention_size = mask->ne[0];
+            }
         }
     }

     {
-        std::string name = "attention_size";
+        auto create_attention_size_input = [this](const std::string& name, int64_t size) {
             auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{1});
             param_node->set_friendly_name(name);
             param_node->output(0).get_tensor().set_names({name});
             m_model_extra_inputs[name] = param_node;

             auto tensor = std::make_shared<ov::Tensor>(ov::element::i64, ov::Shape{1});
-            *tensor->data<int64_t>() = attention_size;
+            *tensor->data<int64_t>() = size;
             m_model_extra_input_values[name] = tensor;
-    }
+        };
+
+        create_attention_size_input("attention_size", attention_size);
+        create_attention_size_input("attention_size_swa", attention_size_swa);
+    }
 }

 const ggml_tensor* GgmlOvDecoder::get_tensor_used_op(const ggml_tensor* tensor) const {
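
Reduced to plain data, the block above registers one scalar int64 graph input per mask flavour; the real code wraps these in ov::Parameter/ov::Tensor pairs, and the values come from the mask widths (mask->ne[0]) found in the loop. A sketch with hypothetical numbers:

    #include <cstdint>
    #include <map>
    #include <string>

    int main() {
        std::map<std::string, int64_t> extra_input_values;
        extra_input_values["attention_size"]     = 1024;  // non-SWA mask width
        extra_input_values["attention_size_swa"] = 512;   // "swa" mask width
        return extra_input_values.size() == 2 ? 0 : 1;
    }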

@@ -706,8 +733,16 @@ int32_t* GgmlOvDecoder::get_output_op_params(const std::string& name) const {

 void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>)> node_visitor) const {
     for (const auto& node : m_nodes) {
-        auto decoder = std::make_shared<GgmlOvDecoder>(
-            node, m_cgraph, m_is_static, m_is_first_token, m_context_size, m_num_heads, m_num_heads_kv, m_head_size);
+        auto decoder = std::make_shared<GgmlOvDecoder>(node,
+                                                       m_cgraph,
+                                                       m_is_static,
+                                                       m_is_first_token,
+                                                       m_context_size,
+                                                       m_context_size_swa,
+                                                       m_num_heads,
+                                                       m_num_heads_kv,
+                                                       m_head_size,
+                                                       m_swa_layers);
         node_visitor(decoder);
     }
 }

@@ -19,7 +19,8 @@ public:

     // Node decoder, called in GgmlOvDecoder::visit_subgraph
     GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token,
-                  int context_size, int num_heads, int num_heads_kv, int head_size);
+                  int context_size, int context_size_swa, int num_heads, int num_heads_kv, int head_size,
+                  const std::vector<int>& swa_layers);

     // Naive graph decoder
     GgmlOvDecoder(struct ggml_cgraph* cgraph, std::map<std::string, std::shared_ptr<ov::Node>>& model_weights);

@@ -101,6 +102,12 @@ public:

     virtual int get_context_size() const override { return m_context_size; }

+    virtual int get_context_size_swa() const override { return m_context_size_swa; }
+
+    virtual int is_swa_layer(int layer) const override {
+        return std::find(m_swa_layers.begin(), m_swa_layers.end(), layer) != m_swa_layers.end();
+    }
+
     virtual int get_num_heads() const override { return m_num_heads; }

     virtual int get_num_heads_kv() const override { return m_num_heads_kv; }

@@ -156,6 +163,8 @@ private:
     std::map<std::string, std::shared_ptr<ov::Node>> m_model_weights;
     std::vector<std::string> m_model_output_names;
     int m_context_size;
+    int m_context_size_swa;
+    std::vector<int> m_swa_layers;
     int m_num_heads;
     int m_num_heads_kv;
     int m_head_size;

@@ -166,3 +175,5 @@ private:
 };

 void print_tensor_address_map(const struct ggml_cgraph* cgraph);
+
+int extract_layer_from_name(const std::string& name);

@@ -67,6 +67,8 @@ public:
    virtual bool is_static() const = 0;
    virtual bool is_first_token() const = 0;
    virtual int get_context_size() const = 0;
+   virtual int get_context_size_swa() const = 0;
+   virtual int is_swa_layer(int layer) const = 0;
 };

 } // namespace ggml

@@ -2,6 +2,7 @@

 #include <cstdint>
+#include <openvino/frontend/node_context.hpp>
 #include <string>

 #include "decoder.hpp"

@@ -30,6 +31,8 @@ public:
         return m_translate_session;
     }

+    const std::vector<std::string>& get_input_names() const { return m_input_names; }
+
     size_t get_input_size() const override {
         return m_decoder->get_input_size();
     }

@@ -101,15 +104,7 @@ public:
         return m_decoder->is_first_token();
     }

-    int get_num_heads() const { return m_decoder->get_num_heads(); }
-
-    int get_num_heads_kv() const { return m_decoder->get_num_heads_kv(); }
-
-    int get_head_size() const { return m_decoder->get_head_size(); }
-
-    int get_context_size() const { return m_decoder->get_context_size(); }
-
 private:
     std::shared_ptr<GgmlDecoder> m_decoder;
     std::shared_ptr<TensorMap>& m_tensor_map;
     TranslateSession* m_translate_session;

@@ -6,6 +6,7 @@
 #include <openvino/op/scaled_dot_product_attention.hpp>
 #include <openvino/op/transpose.hpp>
 #include <openvino/op/unsqueeze.hpp>
+#include <string>

 #include "../node_context.hpp"
 #include "../op_table.hpp"

@@ -32,8 +33,12 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
     auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{}, std::vector<float>{scale});

     ov::Output<ov::Node> mask_sliced;
-    if (context.has_input("KQ_mask_sliced")) {
-        mask_sliced = context.get_input("KQ_mask_sliced");
+    std::string mask_name = "KQ_mask_sliced";
+    if (context.get_input_names()[3].find("swa") != std::string::npos) {
+        mask_name = "KQ_mask_swa_sliced";
+    }
+    if (context.has_input(mask_name)) {
+        mask_sliced = context.get_input(mask_name);
     } else {
         auto token_len = get_dimensions(q, {1});
         auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
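
The mask to consume is chosen from the ggml-side name of the node's fourth input (the attention mask), then resolved against the pre-sliced tensors registered by add_sliced_mask further down. A minimal sketch of the mapping, with hypothetical input names:

    #include <cassert>
    #include <string>

    static std::string pick_mask(const std::string& input3_name) {
        return input3_name.find("swa") != std::string::npos ? "KQ_mask_swa_sliced"
                                                            : "KQ_mask_sliced";
    }

    int main() {
        assert(pick_mask("KQ_mask_swa (f16)") == "KQ_mask_swa_sliced");
        assert(pick_mask("KQ_mask (f16)") == "KQ_mask_sliced");
        return 0;
    }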

@@ -29,43 +29,29 @@ OutputVector translate_permute(const NodeContext& context) {
                                           ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
     } else {
         auto src = context.get_input(0);
-        auto attention_size = context.get_input("attention_size");
+        Output<Node> attention_size;
+        if (context.is_static()) {
+            attention_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {INT_MAX});
+        } else if (op_case == 2) {
+            attention_size = context.get_input("attention_size");
+        } else {
+            attention_size = context.get_input("attention_size_swa");
+        }

         auto src_shape_ = context.get_input_shape(0).to_shape();
         std::vector<int64_t> src_shape(src_shape_.begin(), src_shape_.end());

-        std::shared_ptr<ov::Node> src_reshaped;
-        if (op_case == 2) {
-            src_reshaped = std::make_shared<ov::op::v1::Reshape>(
-                src,
-                ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{-1, src_shape[1], src_shape[2]}),
-                false);
-        } else {
-            src_reshaped = std::make_shared<ov::op::v1::Reshape>(
-                src,
-                ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{src_shape[1], src_shape[0], -1}),
-                false);
-        }
+        auto src_reshaped = std::make_shared<ov::op::v1::Reshape>(
+            src,
+            ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{-1, src_shape[1], src_shape[2]}),
+            false);

         auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
         auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
-        auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
-        std::shared_ptr<ov::Node> slice_axis;
-        if (op_case == 2) {
-            slice_axis = zero;
-        } else {
-            slice_axis = two;
-        }
-        auto src_slice = std::make_shared<ov::op::v8::Slice>(src_reshaped, zero, attention_size, one, slice_axis);
+        auto src_slice = std::make_shared<ov::op::v8::Slice>(src_reshaped, zero, attention_size, one, zero);

-        if (op_case == 2) {
-            res = std::make_shared<ov::op::v1::Transpose>(src_slice, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
-        } else {
-            res = src_slice;
-        }
+        res = std::make_shared<ov::op::v1::Transpose>(src_slice,
+                                                      ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
     }
     return rename_outputs_with_suffix({res}, context.get_name());
 }
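
Two details in the rewritten branch are worth noting. First, op_case 2 and 3 now share a single reshape/slice/transpose path and differ only in which size they slice with (case 3 marks SWA layers, per set_input_output above). Second, in the static case the slice end is INT_MAX, which relies on Slice clamping an out-of-range stop index to the dimension size, so the whole axis is kept. A standalone sketch of that clamping idea, mirroring my understanding of ov::op::v8::Slice semantics:

    #include <algorithm>
    #include <cassert>
    #include <climits>

    int main() {
        const long long dim_size = 4096;  // e.g. rows of the KV cache view
        const long long stop = INT_MAX;   // "slice to the end"
        long long clamped = std::min(stop, dim_size);
        assert(clamped == dim_size);      // nothing is cut off
        return 0;
    }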

@@ -78,13 +78,22 @@ void add_token_len(TensorMap& tensor_map) {
 }

 void add_sliced_mask(TensorMap& tensor_map) {
-    auto mask = tensor_map.at("KQ_mask").get_node_shared_ptr();
     auto token_len = tensor_map.at("token_len").get_node_shared_ptr();
-    auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
-    auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
-    std::shared_ptr<ov::Node> mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
-    mask_sliced->set_friendly_name("KQ_mask_sliced");
-    tensor_map.insert({"KQ_mask_sliced", mask_sliced->output(0)});
+
+    auto create_sliced_mask = [&](const std::string& mask_name, const std::string& sliced_name) {
+        if (tensor_map.find(mask_name) != tensor_map.end()) {
+            auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
+            auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
+            auto mask = tensor_map.at(mask_name).get_node_shared_ptr();
+            std::shared_ptr<ov::Node> mask_sliced =
+                std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
+            mask_sliced->set_friendly_name(sliced_name);
+            tensor_map.insert({sliced_name, mask_sliced->output(0)});
+        }
+    };
+
+    create_sliced_mask("KQ_mask", "KQ_mask_sliced");
+    create_sliced_mask("KQ_mask_swa", "KQ_mask_swa_sliced");
 }

 void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
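
Why the mask is sliced at all: the graph-level mask spans a padded token dimension, while a given step only has token_len valid rows, so the helper cuts the mask down before it reaches attention, and it now does this for whichever of the two masks exist in the tensor map. A standalone sketch of the row-slice semantics (sizes hypothetical):

    #include <cassert>
    #include <vector>

    int main() {
        const int padded_tokens = 32;  // mask rows as built by the graph
        const int kv_len = 1024;       // mask columns
        const int token_len = 1;       // valid rows in this decode step
        std::vector<float> kq_mask(padded_tokens * kv_len, 0.0f);
        // "KQ_mask_sliced" keeps rows [0, token_len) of the token axis:
        std::vector<float> sliced(kq_mask.begin(),
                                  kq_mask.begin() + token_len * kv_len);
        assert(sliced.size() == static_cast<size_t>(token_len) * kv_len);
        return 0;
    }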

@@ -362,7 +362,7 @@ ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, cons
         input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name);
     }

-    } else if (param_name == "KQ_mask") {
+    } else if (param_name.find("KQ_mask") == 0) {
         size_t context_size = ggml_decoder->get_context_size();
         const auto* input_tensor_ggml = ggml_decoder->get_input_ggml_tensor(param_name);
         if (is_first_token) {

@@ -1605,7 +1605,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con

     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
     inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
-    cb(inp->self_kq_mask, "KQ_mask", -1);
+    cb(inp->self_kq_mask, "self_kq_mask", -1);
     ggml_set_input(inp->self_kq_mask);

     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;

@@ -1694,7 +1694,7 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
     inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);

     inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
-    ggml_set_name(inp->self_kq_mask, "KQ_mask");
+    ggml_set_name(inp->self_kq_mask, "self_kq_mask");
     ggml_set_input(inp->self_kq_mask);

     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;