Support iSWA

Yu, Zijun 2025-09-16 16:30:45 +08:00 committed by Mustafa Cavus
parent 7d81861a18
commit 9de874cb7b
9 changed files with 124 additions and 81 deletions

View File

@@ -30,17 +30,21 @@
#include <set>
#include <stdexcept>
#include <string>
#include <vector>
#include "ggml-backend-impl.h"
#include "ggml-backend.h"
#include "ggml-quants.hpp"
GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token,
int context_size, int num_heads, int num_heads_kv, int head_size) :
int context_size, int context_size_swa, int num_heads, int num_heads_kv, int head_size,
const std::vector<int>& swa_layers) :
m_cgraph(cgraph),
m_node(node),
m_op_name(std::string(node->name)),
m_context_size(context_size),
m_context_size_swa(context_size_swa),
m_swa_layers(swa_layers),
m_num_heads(num_heads),
m_num_heads_kv(num_heads_kv),
m_head_size(head_size),
@@ -204,11 +208,14 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
if (node->src[0]->op != GGML_OP_VIEW) {
m_op_case = 1;
} else if (ggml_is_contiguous(node->src[0])) {
// Permute cache_k (view)
m_op_case = 2;
} else {
// Permute cache_v (view), deprecated, cache_v will also fall to case 2
m_op_case = 3;
// Permute kv cache (view)
std::string src_name(node->view_src->name);
int layer = extract_layer_from_name(src_name);
if (!is_swa_layer(layer)) {
m_op_case = 2;
} else {
m_op_case = 3;
}
}
break;
}
@@ -239,13 +246,34 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
}
}
int extract_layer_from_name(const std::string& name) {
size_t pos1 = name.find("_l");
assert(pos1 != std::string::npos);
pos1 += 2;
size_t pos2 = name.find(' ', pos1);
if (pos2 == std::string::npos) {
pos2 = name.length();
}
std::string layer_str = name.substr(pos1, pos2 - pos1);
int layer = std::stoi(layer_str);
return layer;
}
void GgmlOvDecoder::set_llm_params() {
for (int i = 0; i < m_cgraph->n_nodes; i++) {
auto* node = m_cgraph->nodes[i];
std::string name = std::string(node->name);
if (node->op == GGML_OP_VIEW && std::string(node->name) == "cache_k_l0 (view)") {
auto* cache_k = node->src[0];
m_context_size = cache_k->ne[1];
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
auto* cache_k = node->src[1];
cache_k = cache_k->view_src ? cache_k->view_src : cache_k;
int layer = extract_layer_from_name(cache_k->name);
if (std::string(node->src[3]->name).find("swa") != std::string::npos) {
m_swa_layers.push_back(layer);
m_context_size_swa = cache_k->ne[1];
} else {
m_context_size = cache_k->ne[1];
}
} else if (node->op == GGML_OP_ROPE &&
(name.find("Qcur-0") == 0 || std::string(node->src[0]->name).find("Qcur-0") == 0)) {
m_head_size = node->ne[0];
@@ -269,11 +297,11 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) const {
input_shape = ov::PartialShape{1, 1, 1};
}
} else {
input_shape = ov::PartialShape{1, 1, ov::Dimension(1, m_context_size)};
input_shape = ov::PartialShape{1, 1, -1};
}
} else if (name == "inp_out_ids" && !m_is_static) {
input_shape = ov::PartialShape{1, 1, ov::Dimension(1, m_context_size)};
} else if (name == "KQ_mask") {
input_shape = ov::PartialShape{1, 1, -1};
} else if (name.find("KQ_mask") == 0) {
if (m_is_static) {
if (m_is_first_token) {
input_shape = ov::PartialShape{1, m_context_size, m_context_size};
@@ -281,13 +309,12 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) const {
input_shape = ov::PartialShape{1, 1, m_context_size};
}
} else {
auto max_mask_size = GGML_PAD(m_context_size, GGML_KQ_MASK_PAD);
input_shape = ov::PartialShape{1, ov::Dimension(1, max_mask_size), ov::Dimension(1, max_mask_size)};
input_shape = ov::PartialShape{1, -1, -1};
}
} else if (name.find("cache_k") == 0) {
input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size};
} else if (name.find("cache_v") == 0) {
input_shape = ov::PartialShape{m_context_size, m_num_heads_kv, m_head_size};
} else if (name.find("cache_") == 0) {
int layer = extract_layer_from_name(name);
bool is_swa = is_swa_layer(layer);
input_shape = ov::PartialShape{is_swa ? m_context_size_swa : m_context_size, m_num_heads_kv, m_head_size};
} else if (const auto* op = get_tensor_used_op(src); op && op->op == GGML_OP_SET_ROWS) {
input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
} else if (src->op == GGML_OP_VIEW) {
@@ -305,35 +332,35 @@ void GgmlOvDecoder::add_extra_inputs() {
// see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding.
// Not used for NPU
int64_t attention_size = -1;
int64_t attention_size_swa = -1;
for (const auto& node : m_nodes) {
if (node->op == GGML_OP_SOFT_MAX) {
auto* mask = node->src[1];
if (std::string(mask->name).find("KQ_mask") != 0) {
throw std::runtime_error("Unexpected softmax node: " + std::string(mask->name));
}
attention_size = mask->ne[0];
break;
}
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
auto* mask = node->src[3];
if (std::string(mask->name).find("KQ_mask") != 0) {
std::string mask_name(mask->name);
if (mask_name.find("KQ_mask") != 0) {
throw std::runtime_error("Unexpected flash attention node: " + std::string(mask->name));
}
attention_size = mask->ne[0];
if (mask_name.find("swa") != std::string::npos) {
attention_size_swa = mask->ne[0];
} else {
attention_size = mask->ne[0];
}
}
}
{
std::string name = "attention_size";
auto create_attention_size_input = [this](const std::string& name, int64_t size) {
auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{1});
param_node->set_friendly_name(name);
param_node->output(0).get_tensor().set_names({name});
m_model_extra_inputs[name] = param_node;
auto tensor = std::make_shared<ov::Tensor>(ov::element::i64, ov::Shape{1});
*tensor->data<int64_t>() = attention_size;
*tensor->data<int64_t>() = size;
m_model_extra_input_values[name] = tensor;
}
};
create_attention_size_input("attention_size", attention_size);
create_attention_size_input("attention_size_swa", attention_size_swa);
}
const ggml_tensor* GgmlOvDecoder::get_tensor_used_op(const ggml_tensor* tensor) const {
@@ -706,8 +733,16 @@ int32_t* GgmlOvDecoder::get_output_op_params(const std::string& name) const {
void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>)> node_visitor) const {
for (const auto& node : m_nodes) {
auto decoder = std::make_shared<GgmlOvDecoder>(
node, m_cgraph, m_is_static, m_is_first_token, m_context_size, m_num_heads, m_num_heads_kv, m_head_size);
auto decoder = std::make_shared<GgmlOvDecoder>(node,
m_cgraph,
m_is_static,
m_is_first_token,
m_context_size,
m_context_size_swa,
m_num_heads,
m_num_heads_kv,
m_head_size,
m_swa_layers);
node_visitor(decoder);
}
}
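For reference, a minimal self-contained sketch of the new name-based layer lookup added in this file. The cache tensor names follow the "cache_k_l<N>" / "cache_v_l<N>" pattern seen in the graph; the SWA layer list below is a made-up example, not taken from a real model.
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>
// Same parsing rule as GgmlOvDecoder's extract_layer_from_name: take the digits
// that follow "_l", stopping at the first space so a trailing " (view)" is ignored.
static int extract_layer_from_name(const std::string& name) {
    size_t pos1 = name.find("_l");
    assert(pos1 != std::string::npos);
    pos1 += 2;
    size_t pos2 = name.find(' ', pos1);
    if (pos2 == std::string::npos) {
        pos2 = name.length();
    }
    return std::stoi(name.substr(pos1, pos2 - pos1));
}
int main() {
    assert(extract_layer_from_name("cache_k_l0 (view)") == 0);
    assert(extract_layer_from_name("cache_v_l12") == 12);
    // With a per-model list of SWA layers, is_swa_layer becomes a simple lookup,
    // mirroring the new GgmlOvDecoder::is_swa_layer member. The indices here are hypothetical.
    std::vector<int> swa_layers = {0, 1, 2};
    auto is_swa_layer = [&](int layer) {
        return std::find(swa_layers.begin(), swa_layers.end(), layer) != swa_layers.end();
    };
    assert(is_swa_layer(extract_layer_from_name("cache_k_l1 (view)")));
    assert(!is_swa_layer(extract_layer_from_name("cache_k_l3")));
    return 0;
}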

View File

@@ -19,7 +19,8 @@ public:
// Node decoder, called in GgmlOvDecoder::visit_subgraph
GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgraph, bool is_static, bool is_first_token,
int context_size, int num_heads, int num_heads_kv, int head_size);
int context_size, int context_size_swa, int num_heads, int num_heads_kv, int head_size,
const std::vector<int>& swa_layers);
// Naive graph decoder
GgmlOvDecoder(struct ggml_cgraph* cgraph, std::map<std::string, std::shared_ptr<ov::Node>>& model_weights);
@@ -101,6 +102,12 @@ public:
virtual int get_context_size() const override { return m_context_size; }
virtual int get_context_size_swa() const override { return m_context_size_swa; }
virtual int is_swa_layer(int layer) const override {
return std::find(m_swa_layers.begin(), m_swa_layers.end(), layer) != m_swa_layers.end();
}
virtual int get_num_heads() const override { return m_num_heads; }
virtual int get_num_heads_kv() const override { return m_num_heads_kv; }
@@ -156,6 +163,8 @@ private:
std::map<std::string, std::shared_ptr<ov::Node>> m_model_weights;
std::vector<std::string> m_model_output_names;
int m_context_size;
int m_context_size_swa;
std::vector<int> m_swa_layers;
int m_num_heads;
int m_num_heads_kv;
int m_head_size;
@@ -166,3 +175,5 @@
};
void print_tensor_address_map(const struct ggml_cgraph* cgraph);
int extract_layer_from_name(const std::string& name);

View File

@@ -67,6 +67,8 @@ public:
virtual bool is_static() const = 0;
virtual bool is_first_token() const = 0;
virtual int get_context_size() const = 0;
virtual int get_context_size_swa() const = 0;
virtual int is_swa_layer(int layer) const = 0;
};
} // namespace ggml

View File

@@ -2,6 +2,7 @@
#include <cstdint>
#include <openvino/frontend/node_context.hpp>
#include <string>
#include "decoder.hpp"
@@ -30,6 +31,8 @@ public:
return m_translate_session;
}
const std::vector<std::string>& get_input_names() const { return m_input_names; }
size_t get_input_size() const override {
return m_decoder->get_input_size();
}
@@ -101,15 +104,7 @@ public:
return m_decoder->is_first_token();
}
int get_num_heads() const { return m_decoder->get_num_heads(); }
int get_num_heads_kv() const { return m_decoder->get_num_heads_kv(); }
int get_head_size() const { return m_decoder->get_head_size(); }
int get_context_size() const { return m_decoder->get_context_size(); }
private:
std::shared_ptr<GgmlDecoder> m_decoder;
std::shared_ptr<TensorMap>& m_tensor_map;
TranslateSession* m_translate_session;

View File

@@ -6,6 +6,7 @@
#include <openvino/op/scaled_dot_product_attention.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/op/unsqueeze.hpp>
#include <string>
#include "../node_context.hpp"
#include "../op_table.hpp"
@@ -32,8 +33,12 @@ OutputVector translate_flash_attn_ext(const NodeContext& context) {
auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{}, std::vector<float>{scale});
ov::Output<ov::Node> mask_sliced;
if (context.has_input("KQ_mask_sliced")) {
mask_sliced = context.get_input("KQ_mask_sliced");
std::string mask_name = "KQ_mask_sliced";
if (context.get_input_names()[3].find("swa") != std::string::npos) {
mask_name = "KQ_mask_swa_sliced";
}
if (context.has_input(mask_name)) {
mask_sliced = context.get_input(mask_name);
} else {
auto token_len = get_dimensions(q, {1});
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});

View File

@@ -29,43 +29,29 @@ OutputVector translate_permute(const NodeContext& context) {
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
} else {
auto src = context.get_input(0);
auto attention_size = context.get_input("attention_size");
Output<Node> attention_size;
if (context.is_static()) {
attention_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {INT_MAX});
} else if (op_case == 2) {
attention_size = context.get_input("attention_size");
} else {
attention_size = context.get_input("attention_size_swa");
}
auto src_shape_ = context.get_input_shape(0).to_shape();
std::vector<int64_t> src_shape(src_shape_.begin(), src_shape_.end());
std::shared_ptr<ov::Node> src_reshaped;
if (op_case == 2) {
src_reshaped = std::make_shared<ov::op::v1::Reshape>(
src,
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{-1, src_shape[1], src_shape[2]}),
false);
} else {
src_reshaped = std::make_shared<ov::op::v1::Reshape>(
src,
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{src_shape[1], src_shape[0], -1}),
false);
}
auto src_reshaped = std::make_shared<ov::op::v1::Reshape>(
src,
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{-1, src_shape[1], src_shape[2]}),
false);
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
std::shared_ptr<ov::Node> slice_axis;
if (op_case == 2) {
slice_axis = zero;
} else {
slice_axis = two;
}
auto src_slice = std::make_shared<ov::op::v8::Slice>(src_reshaped, zero, attention_size, one, slice_axis);
auto src_slice = std::make_shared<ov::op::v8::Slice>(src_reshaped, zero, attention_size, one, zero);
if (op_case == 2) {
res = std::make_shared<ov::op::v1::Transpose>(src_slice, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
} else {
res = src_slice;
}
res = std::make_shared<ov::op::v1::Transpose>(src_slice,
ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
}
return rename_outputs_with_suffix({res}, context.get_name());
}
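To make the simplified non-static path above concrete, here is a stand-alone sketch that builds the same Reshape -> Slice -> Transpose chain on a dummy kv-cache parameter. The shape values (1024 context, 8 KV heads, head size 128) are invented for illustration, and the [context_size, num_heads_kv, head_size] cache layout is the one used in get_graph_input_shape; the "attention_size" input is the scalar extra input created in add_extra_inputs (or "attention_size_swa" for SWA layers).
#include <cstdint>
#include <memory>
#include <vector>
#include <openvino/core/model.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/parameter.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/transpose.hpp>
int main() {
    const int64_t num_heads_kv = 8, head_size = 128;
    // Dummy kv-cache input: [context_size, num_heads_kv, head_size].
    auto cache = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::Shape{1024, 8, 128});
    // Scalar attention-size extra input, a 1-element i64 tensor.
    auto attention_size = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{1});
    // Keep the last two dims, leave the token dim dynamic.
    auto src_reshaped = std::make_shared<ov::op::v1::Reshape>(
        cache,
        ov::op::v0::Constant::create(ov::element::i64, {3},
                                     std::vector<int64_t>{-1, num_heads_kv, head_size}),
        false);
    auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
    auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
    // Slice axis 0 (the token dimension) down to the current attention window.
    auto src_slice = std::make_shared<ov::op::v8::Slice>(src_reshaped, zero, attention_size, one, zero);
    // [attention_size, num_heads_kv, head_size] -> [num_heads_kv, attention_size, head_size]
    auto res = std::make_shared<ov::op::v1::Transpose>(
        src_slice, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
    ov::Model model(ov::OutputVector{res->output(0)}, ov::ParameterVector{cache, attention_size});
    return 0;
}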

View File

@@ -78,13 +78,22 @@ void add_token_len(TensorMap& tensor_map) {
}
void add_sliced_mask(TensorMap& tensor_map) {
auto mask = tensor_map.at("KQ_mask").get_node_shared_ptr();
auto token_len = tensor_map.at("token_len").get_node_shared_ptr();
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
std::shared_ptr<ov::Node> mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
mask_sliced->set_friendly_name("KQ_mask_sliced");
tensor_map.insert({"KQ_mask_sliced", mask_sliced->output(0)});
auto create_sliced_mask = [&](const std::string& mask_name, const std::string& sliced_name) {
if (tensor_map.find(mask_name) != tensor_map.end()) {
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto mask = tensor_map.at(mask_name).get_node_shared_ptr();
std::shared_ptr<ov::Node> mask_sliced =
std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, one);
mask_sliced->set_friendly_name(sliced_name);
tensor_map.insert({sliced_name, mask_sliced->output(0)});
}
};
create_sliced_mask("KQ_mask", "KQ_mask_sliced");
create_sliced_mask("KQ_mask_swa", "KQ_mask_swa_sliced");
}
void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) {
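The two sliced masks registered above are what the flash-attention translator earlier in this commit looks up by name. A tiny sketch of that naming convention, with the selection check factored into a hypothetical helper for illustration:
#include <cassert>
#include <string>
// Mirrors the check in translate_flash_attn_ext: if the mask feeding the
// FLASH_ATTN_EXT node has "swa" in its name, use the SWA-specific slice.
static std::string sliced_mask_name(const std::string& mask_input_name) {
    return mask_input_name.find("swa") != std::string::npos ? "KQ_mask_swa_sliced"
                                                            : "KQ_mask_sliced";
}
int main() {
    assert(sliced_mask_name("KQ_mask") == "KQ_mask_sliced");
    assert(sliced_mask_name("KQ_mask_swa") == "KQ_mask_swa_sliced");
    return 0;
}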

View File

@@ -362,7 +362,7 @@ ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, cons
input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name);
}
} else if (param_name == "KQ_mask") {
} else if (param_name.find("KQ_mask") == 0) {
size_t context_size = ggml_decoder->get_context_size();
const auto* input_tensor_ggml = ggml_decoder->get_input_ggml_tensor(param_name);
if (is_first_token) {

View File

@@ -1605,7 +1605,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1);
cb(inp->self_kq_mask, "KQ_mask", -1);
cb(inp->self_kq_mask, "self_kq_mask", -1);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1694,7 +1694,7 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_name(inp->self_kq_mask, "KQ_mask");
ggml_set_name(inp->self_kq_mask, "self_kq_mask");
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;