Change graph to 4D, support multiple sequences
This commit is contained in:
parent ea2c99be1c
commit 072dde0b2b
@@ -44,26 +44,26 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_tensor * node,
int num_heads_kv,
int head_size,
const std::vector<int> & swa_layers) :
m_is_static(is_static),
m_cgraph(cgraph),
m_node(node),
m_op_name(std::string(node->name)),
m_context_size(context_size),
m_context_size_swa(context_size_swa),
m_swa_layers(swa_layers),
m_num_heads(num_heads),
m_num_heads_kv(num_heads_kv),
m_ctx(context_size),
m_ctx_swa(context_size_swa),
m_n_heads(num_heads),
m_n_heads_kv(num_heads_kv),
m_head_size(head_size),
m_is_static(is_static) {
m_swa_layers(swa_layers) {
set_input_output(node);
}

GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,
std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
bool is_static) :
m_is_static(is_static),
m_cgraph(cgraph),
m_op_name(m_node ? std::string(m_node->name) : ""),
m_model_weights(model_weights),
m_is_static(is_static) {
m_model_weights(model_weights) {
if (auto * env = getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS"); env && std::string(env) != "0") {
unsetenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS");
print_tensor_address_map(cgraph);

@@ -78,7 +78,7 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,
set_input_output(cur_node);
}

// add_extra_inputs();
add_extra_inputs();
}

GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph, std::map<std::string, std::shared_ptr<ov::Node>> & model_weights) {

@@ -125,7 +125,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
// Add model inputs and weights constants, if called for the whole graph
if (naive) {
if (m_model_weights.find(src_name) == m_model_weights.end()) {
auto param_node = std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(src));
auto param_node =
std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(node, src));
param_node->set_friendly_name(src_name);
param_node->output(0).get_tensor().set_names({src_name});
m_model_inputs[src_name] = param_node;

@@ -142,7 +143,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
if (m_model_inputs.find(src_name) != m_model_inputs.end()) {
continue;
}
auto param_node = std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(src));
auto param_node =
std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), get_graph_input_shape(node, src));
param_node->set_friendly_name(src_name);
param_node->output(0).get_tensor().set_names({src_name});
m_model_inputs[src_name] = param_node;

@@ -175,15 +177,20 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
if (m_node) {
switch (node->op) {
case GGML_OP_RESHAPE: {
if (node->src[0]->op == GGML_OP_RESHAPE && node->src[0]->src[0]->ne[0] == node->ne[0] &&
node->src[0]->src[0]->ne[1] == node->ne[1]) {
auto * src = node->src[0];
if (src->op == GGML_OP_RESHAPE && src->src[0]->ne[0] == node->ne[0] && src->src[0]->ne[1] == node->ne[1]) {
m_op_case = 4;
} else if (node->ne[0] * node->ne[1] == node->src[0]->ne[0]) {
} else if (node->ne[0] * node->ne[1] == src->ne[0]) {
m_op_case = 1;
} else if (node->src[0]->ne[0] * node->src[0]->ne[1] == node->ne[0]) {
} else if (src->ne[0] * src->ne[1] == node->ne[0]) {
m_op_case = 2;
} else if (node->src[0]->ne[0] * node->src[0]->ne[1] == node->ne[1]) {
if (src->ne[2] * src->ne[3] == node->ne[1]) {
m_op_case = 5;
}
} else if (src->ne[0] * src->ne[1] == node->ne[1]) {
m_op_case = 3;
} else if (src->ne[1] * src->ne[2] == node->ne[1]) {
m_op_case = 6;
}
break;
}

@@ -204,7 +211,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
} else if (ggml_is_contiguous(node->src[0])) {
std::string src_name(node->view_src->name);
if (src_name.find("cache") == std::string::npos) {
m_op_case = 1;
// permute Qcur
m_op_case = 4;
} else {
// Permute kv cache (view)
int layer = extract_layer_from_name(src_name);

@@ -241,10 +249,10 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
case GGML_OP_VIEW: {
if (node->src[0]->op == GGML_OP_VIEW) {
auto * src = node->src[0];
auto * view_src = src->view_src;
if (view_src->ne[1] != src->ne[2]) {
if (ggml_nelements(node) != ggml_nelements(src)) {
throw std::runtime_error("Unsupported VIEW case");
}
// This view is a reshape, slicing happens at src->op
m_op_case = 2;
}
}
@@ -272,64 +280,80 @@ void GgmlOvDecoder::set_llm_params() {
auto * node = m_cgraph->nodes[i];
std::string name = std::string(node->name);
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
auto * cache_k = node->src[1];
cache_k = cache_k->view_src ? cache_k->view_src : cache_k;
auto * cache_k_perm = node->src[1];
assert(cache_k_perm->op == GGML_OP_PERMUTE);
auto * cache_k_view = cache_k_perm->src[0];
assert(cache_k_view->op == GGML_OP_VIEW);

auto * cache_k = cache_k_view->src[0];
int layer = extract_layer_from_name(cache_k->name);
auto * mask = node->src[3];
std::string mask_name(mask->name);
assert(mask_name.find("KQ_mask") == 0);

if (std::string(node->src[3]->name).find("swa") != std::string::npos) {
m_swa_layers.push_back(layer);
m_context_size_swa = cache_k->ne[1];
m_ctx_per_seq_swa = cache_k->ne[1];
} else {
m_context_size = cache_k->ne[1];
m_ctx_per_seq = cache_k->ne[1];
m_n_seq = cache_k->ne[2];
}

m_n_seq_active = mask->ne[3];
auto seq_size = cache_k->ne[0] * cache_k->ne[1] * ggml_type_size(cache_k->type);
m_seq_active_start = ((size_t *) cache_k_view->op_params)[0] / seq_size;
m_token_len_per_seq = node->ne[2];

if (mask_name.find("swa") != std::string::npos) {
m_attention_size_swa = mask->ne[0];
} else {
m_attention_size = mask->ne[0];
}

} else if (node->op == GGML_OP_ROPE) {
if (name.find("Qcur-0") == 0 || std::string(node->src[0]->name).find("Qcur-0") == 0) {
m_head_size = node->ne[0];
m_num_heads = node->ne[1];
m_n_heads = node->ne[1];
m_rope_params = node->op_params;
auto * inp_pos = node->src[1];
m_input_len = inp_pos->ne[0];
m_past_kv_len = *(int32_t *) inp_pos->data;
} else if (name.find("Kcur-0") == 0 || std::string(node->src[0]->name).find("Kcur-0") == 0) {
m_num_heads_kv = node->ne[1];
m_n_heads_kv = node->ne[1];
}
}
}
m_ctx = m_ctx_per_seq * m_n_seq;
m_ctx_swa = m_ctx_per_seq_swa * m_n_seq;
}

void GgmlOvDecoder::validate_cgraph() const {}
void GgmlOvDecoder::validate_cgraph() const {
if (m_n_seq > 1 && m_is_static == true) {
throw std::runtime_error("n_seq > 1 is not supported on NPU");
}
}

ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * src) const {
auto name = std::string(src->name);
ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const {
auto name = std::string(input->name);
ov::PartialShape input_shape;

if (name == "inp_tokens" || name == "inp_pos" || name == "inp_out_ids") {
input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};
input_shape = ov::PartialShape{1, 1, 1, m_is_static ? 1 : -1};

} else if (name.find("KQ_mask") == 0) {
if (m_is_static) {
input_shape = ov::PartialShape{1, 1, m_context_size};
input_shape = ov::PartialShape{1, 1, 1, m_ctx};
} else {
input_shape = ov::PartialShape{1, -1, -1};
input_shape = ov::PartialShape{-1, 1, -1, -1};
}

} else if (name.find("cache_") == 0) {
auto past_token_len = -1;
if (m_is_static) {
int layer = extract_layer_from_name(name);
bool is_swa = is_swa_layer(layer);
past_token_len = is_swa ? m_context_size_swa : m_context_size;
}
input_shape = ov::PartialShape{past_token_len, m_num_heads_kv, m_head_size};
} else if (op && op->op == GGML_OP_SET_ROWS && op->src[1] == input) {
input_shape = ov::PartialShape{1, 1, 1, m_is_static ? 1 : -1};

} else if (const auto * op = get_tensor_used_op(src); op && op->op == GGML_OP_SET_ROWS) {
input_shape = ov::PartialShape{1, 1, m_is_static ? 1 : -1};

} else if (src->op == GGML_OP_VIEW) {
} else if (input->op == GGML_OP_VIEW) {
// This case is added to make test-backend-ops work
input_shape = ov::PartialShape{get_shape(src->view_src)};
input_shape = ov::PartialShape{get_shape(input->view_src)};
} else {
input_shape = ov::PartialShape{get_shape(src)};
input_shape = ov::PartialShape{get_shape(input)};
}
return input_shape;
}
@@ -339,25 +363,9 @@ void GgmlOvDecoder::add_extra_inputs() {
// 1. `attention_size`, used in FLASH_ATTN where the shape of the matmul's are 256 aligned,
// see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding.
// Not used for NPU.
// Update: not used anymore after the optimization of making kvcache dynamic (but breaks iSWA models)
int64_t attention_size = -1;
int64_t attention_size_swa = -1;
for (const auto & node : m_nodes) {
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
auto * mask = node->src[3];
std::string mask_name(mask->name);
if (mask_name.find("KQ_mask") != 0) {
throw std::runtime_error("Unexpected flash attention node: " + std::string(mask->name));
}
if (mask_name.find("swa") != std::string::npos) {
attention_size_swa = mask->ne[0];
} else {
attention_size = mask->ne[0];
}
}
}
// 2. `n_seq_active` and `seq_active_start`, used in FLASH_ATTN_EXT to indicate the active sequences in the batch

auto create_attention_size_input = [this](const std::string & name, int64_t size) {
auto create_1d_input = [this](const std::string & name, int64_t size) {
auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{1});
param_node->set_friendly_name(name);
param_node->output(0).get_tensor().set_names({name});

@@ -368,10 +376,15 @@ void GgmlOvDecoder::add_extra_inputs() {
m_model_extra_input_values[name] = tensor;
};

create_attention_size_input("attention_size", attention_size);
if (attention_size_swa != -1) {
create_attention_size_input("attention_size_swa", attention_size_swa);
create_1d_input("attention_size", m_attention_size);
if (m_attention_size_swa != -1) {
create_1d_input("attention_size_swa", m_attention_size_swa);
}
create_1d_input("n_seq_active", m_n_seq_active);
create_1d_input("seq_active_start", m_seq_active_start);
create_1d_input("seq_active_end", m_seq_active_start + m_n_seq_active);
create_1d_input("token_len_per_seq", m_token_len_per_seq);
// create_1d_input("token_len", m_token_len_per_seq * m_n_seq_active);
}

const ggml_tensor * GgmlOvDecoder::get_tensor_used_op(const ggml_tensor * tensor) const {

@@ -472,6 +485,8 @@ std::shared_ptr<ov::Node> GgmlOvDecoder::create_weight_node(ggml_tensor * tensor
auto node_shape = get_shape(tensor);
auto ne_total = ggml_nelements(tensor);

OPENVINO_ASSERT(node_shape[0] == 1, "Got 4D weights, expect all weights to be 2D: ", tensor->name);
node_shape.erase(node_shape.begin());
OPENVINO_ASSERT(node_shape[0] == 1, "Got 3D weights, expect all weights to be 2D: ", tensor->name);
node_shape.erase(node_shape.begin());

@@ -641,7 +656,7 @@ void print_tensor_address_map(const ggml_cgraph * cgraph) {
std::vector<size_t> GgmlOvDecoder::get_shape(const ggml_tensor * tensor) {
std::vector<size_t> shape;
for (int i = GGML_MAX_DIMS - 2; i >= 0; --i) {
for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) {
shape.push_back(static_cast<size_t>(tensor->ne[i]));
}
return shape;

@@ -649,7 +664,7 @@ std::vector<size_t> GgmlOvDecoder::get_shape(const ggml_tensor * tensor) {
std::vector<size_t> GgmlOvDecoder::get_stride(const ggml_tensor * tensor) {
std::vector<size_t> stride;
for (int i = GGML_MAX_DIMS - 2; i >= 0; --i) {
for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) {
stride.push_back(static_cast<size_t>(tensor->nb[i]));
}
return stride;

@@ -708,7 +723,11 @@ std::vector<size_t> GgmlOvDecoder::get_output_stride(const std::string & name) c
}

ov::PartialShape GgmlOvDecoder::get_output_shape(const std::string & name) const {
return ov::PartialShape(get_shape(m_outputs.at(name)));
auto * ggml_tensor = m_outputs.at(name);
if (ggml_tensor->op == GGML_OP_SET_ROWS) {
ggml_tensor = ggml_tensor->view_src;
}
return ov::PartialShape(get_shape(ggml_tensor));
}

ov::element::Type GgmlOvDecoder::get_output_type(const std::string & name) const {

@@ -738,8 +757,8 @@ int32_t * GgmlOvDecoder::get_output_op_params(const std::string & name) const {
void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>)> node_visitor) const {
for (const auto & node : m_nodes) {
auto decoder = std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_context_size, m_context_size_swa,
m_num_heads, m_num_heads_kv, m_head_size, m_swa_layers);
auto decoder = std::make_shared<GgmlOvDecoder>(node, m_cgraph, m_is_static, m_ctx, m_ctx_swa, m_n_heads,
m_n_heads_kv, m_head_size, m_swa_layers);
node_visitor(decoder);
}
}
@@ -103,20 +103,20 @@ public:
virtual const std::vector<std::string> & get_model_output_names() const override { return m_model_output_names; }

virtual int get_context_size() const override { return m_context_size; }
virtual int get_ctx_size() const { return m_ctx; }

virtual int get_context_size_swa() const override { return m_context_size_swa; }
virtual int get_ctx_swa_size() const { return m_ctx_swa; }

virtual int get_ctx_per_seq() const { return m_ctx_per_seq; }

virtual int get_ctx_per_seq_swa() const { return m_ctx_per_seq_swa; }

virtual int get_n_seq() const { return m_n_seq; }

virtual int is_swa_layer(int layer) const override {
return std::find(m_swa_layers.begin(), m_swa_layers.end(), layer) != m_swa_layers.end();
}

virtual int get_num_heads() const override { return m_num_heads; }

virtual int get_num_heads_kv() const override { return m_num_heads_kv; }

virtual int get_head_size() const override { return m_head_size; }

int get_past_kv_len() const { return m_past_kv_len; }

int get_input_len() const { return m_input_len; }

@@ -127,7 +127,7 @@ public:
virtual bool is_static() const override { return m_is_static; }

ov::PartialShape get_graph_input_shape(const ggml_tensor * src) const;
ov::PartialShape get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const;

static void dump_cgraph(const ggml_cgraph * cgraph, std::string & filename);

@@ -151,10 +151,11 @@ private:
static std::vector<size_t> get_stride(const ggml_tensor * tensor);
static ov::element::Type get_ov_type(const ggml_tensor * tensor);

// set context_size, num_heads, etc
void set_llm_params();
void validate_cgraph() const;

bool m_is_static = false;

ggml_cgraph * m_cgraph = nullptr;
ggml_tensor * m_node = nullptr;
std::vector<ggml_tensor *> m_nodes;

@@ -171,17 +172,28 @@ private:
std::map<std::string, std::shared_ptr<ov::Tensor>> m_model_extra_input_values;
std::map<std::string, std::shared_ptr<ov::Node>> m_model_weights;
std::vector<std::string> m_model_output_names;
int m_context_size;
int m_context_size_swa;

// Fixed for a model
int m_ctx = -1;
int m_ctx_swa = -1;
int m_ctx_per_seq = -1;
int m_ctx_per_seq_swa = -1;
int m_n_seq = -1;
int m_n_heads = -1;
int m_n_heads_kv = -1;
int m_head_size = -1;
std::vector<int> m_swa_layers;
int m_num_heads;
int m_num_heads_kv;
int m_head_size;
int m_past_kv_len;
int m_input_len;
int32_t * m_rope_params;
std::vector<std::string> m_kv_names;
bool m_is_static = false;

// Changed per inference
int m_n_seq_active = -1;
int m_seq_active_start = -1;
int m_attention_size = -1;
int m_attention_size_swa = -1;
int m_input_len = -1;
int m_token_len_per_seq = -1;
int m_past_kv_len = -1;
int32_t * m_rope_params = nullptr;
};

void print_tensor_address_map(const ggml_cgraph * cgraph);
@@ -329,10 +329,6 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
GGML_LOG_WARN("OpenVINO backend does not support tensor type %s\n", ggml_type_name(op->type));
return false;
}
if (op->ne[3] != 1) {
GGML_LOG_WARN("OpenVINO backend does not support tensors with ne[3] != 1\n");
return false;
}
for (int i = 0; i < GGML_MAX_SRC; i++) {
auto * src = op->src[i];
if (src == nullptr) {

@@ -342,10 +338,6 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con
GGML_LOG_WARN("OpenVINO backend does not support tensor type %s\n", ggml_type_name(src->type));
return false;
}
if (src->ne[3] != 1) {
GGML_LOG_WARN("OpenVINO backend does not support tensors with ne[3] != 1\n");
return false;
}
if (ggml_is_quantized(src->type) && src->ne[2] != 1) {
GGML_LOG_WARN("OpenVINO backend does not support 3D quantized tensors\n");
return false;
@@ -58,15 +58,11 @@ public:
virtual const std::map<std::string, std::shared_ptr<ov::Node>>& get_model_weights() const = 0;
virtual const std::vector<std::string>& get_model_output_names() const = 0;

virtual int get_num_heads() const = 0;
virtual int get_num_heads_kv() const = 0;
virtual int get_head_size() const = 0;
virtual int32_t* get_rope_params() const = 0;
virtual std::map<std::string, std::string> get_kv_param_res_names() const = 0;

virtual bool is_static() const = 0;
virtual int get_context_size() const = 0;
virtual int get_context_size_swa() const = 0;

virtual int is_swa_layer(int layer) const = 0;
};
@@ -27,6 +27,7 @@ OutputVector translate_cont(const NodeContext & context) {
if (op_case == 1) {
// The input comes from a PERMUTE
throw std::runtime_error("Code of this case might be outdated");
dst_shape[1] = -1;
res = std::make_shared<ov::op::v1::Reshape>(
context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {dst_shape.size()}, dst_shape), false);
@@ -42,17 +42,11 @@ OutputVector translate_flash_attn_ext(const NodeContext & context) {
if (context.has_input(mask_name)) {
mask_sliced = context.get_input(mask_name);
} else {
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
auto token_len = get_dimensions(q, {2});
auto kv_len = get_dimensions(k.get_node_shared_ptr(), {2});

auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 0});
auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 1});
auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 2});

auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, kv_len}, 0);
mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, two);
}

if (mask_sliced.get_element_type() != ov::element::f16) {

@@ -63,27 +57,29 @@ OutputVector translate_flash_attn_ext(const NodeContext & context) {
int64_t factor = num_heads / num_heads_kv;
if (factor > 1) {
ov::Output<ov::Node> kv_broadcast_shape, kv_unsqueezed, new_kv_shape;
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {1});
auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2});
kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes);

kv_broadcast_shape =
ov::op::v0::Constant::create(ov::element::i64, {4}, {num_heads_kv, factor, (int64_t) 1, head_size});
new_kv_shape = ov::op::v0::Constant::create(ov::element::i64, {3}, {num_heads, (int64_t) -1, head_size});
kv_broadcast_shape = ov::op::v0::Constant::create(
ov::element::i64, {5}, {(int64_t) 1, num_heads_kv, factor, (int64_t) 1, head_size});
new_kv_shape =
ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 0, num_heads, (int64_t) -1, head_size});

kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape,
ov::op::BroadcastType::BIDIRECTIONAL);
kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, false);
kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, true);
}
return kv;
};

auto q_shape = context.get_input_shape(0).to_shape();
auto k_shape = context.get_input_shape(1).to_shape();
k = tile_kv(q_shape[0], k_shape[0], q_shape[2], k);
v = tile_kv(q_shape[0], k_shape[0], q_shape[2], v);
k = tile_kv(q_shape[1], k_shape[1], q_shape[3], k);
v = tile_kv(q_shape[1], k_shape[1], q_shape[3], v);

auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_sliced, scale_node, false);
res = std::make_shared<ov::op::v1::Transpose>(sdpa, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
res = std::make_shared<ov::op::v1::Transpose>(sdpa,
ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}));
res = std::make_shared<ov::op::v0::Convert>(res, ov::element::f32);
return rename_outputs_with_suffix({res}, context.get_name());
}
@@ -8,6 +8,7 @@
#include <openvino/op/convert.hpp>
#include <openvino/op/gather.hpp>
#include <openvino/op/squeeze.hpp>
#include <openvino/op/unsqueeze.hpp>

namespace ov {
namespace frontend {

@@ -28,11 +29,13 @@ OutputVector translate_get_rows(const NodeContext & context) {
indices = process_view_input(context, 1);
}

// data[b,x,y] ind[1,b,x'] test-backend-ops case
// data[x,y] ind[1,1,x'] normal case
indices = std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
if (data.get_partial_shape().rank() == 3) {
// data[1,b,x,y] ind[1,1,b,x'] test-backend-ops case
// data[x,y] ind[1,1,1,x'] normal case
indices =
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
if (data.get_partial_shape().rank() == 4) {
auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1});
data = std::make_shared<ov::op::v0::Squeeze>(data, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
res = std::make_shared<ov::op::v8::Gather>(data, indices, axis, 1);
} else {
auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0});

@@ -42,6 +45,7 @@ OutputVector translate_get_rows(const NodeContext & context) {
if (res.get_element_type() != context.get_output_type(0)) {
res = std::make_shared<ov::op::v0::Convert>(res, context.get_output_type(0));
}
res = std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0}));
return rename_outputs_with_suffix({res}, context.get_name());
}
@@ -26,7 +26,7 @@ OutputVector translate_glu_geglu(const NodeContext & context) {
src1 = context.get_input(1);
} else {
auto combined = context.get_input(0);
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {2});
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {3});
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
src0 = split->output(0);
src1 = split->output(1);

@@ -26,7 +26,7 @@ OutputVector translate_glu_swiglu(const NodeContext & context) {
src1 = context.get_input(1);
} else {
auto combined = context.get_input(0);
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {2});
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {3});
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
src0 = split->output(0);
src1 = split->output(1);
@@ -53,6 +53,7 @@ OutputVector translate_mulmat(const NodeContext & context) {
Output<Node> Z = A_batch_larger ? B : A;
int64_t factor = A_batch_larger ? A_batch / B_batch : B_batch / A_batch;
if (factor > 1) {
// TODO code is outdated
auto A_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{A_batch});
auto B_batch_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{B_batch});
auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor});
@@ -6,12 +6,12 @@
#include <cstdint>
#include <memory>
#include <openvino/core/node.hpp>
#include <openvino/op/add.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/reshape.hpp>
#include <openvino/op/slice.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/op/unsqueeze.hpp>

namespace ov {
namespace frontend {

@@ -22,12 +22,64 @@ OutputVector translate_permute(const NodeContext & context) {
num_inputs_check(context, 1, 1);

int op_case = context.get_op_case();
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported PERMUTE case");
ov::Output<Node> res;
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3 || op_case == 4,
"Unsupported PERMUTE case");

ov::Output<Node> res;
auto src = context.get_input(0);
res = std::make_shared<ov::op::v1::Transpose>(src, ov::op::v0::Constant::create(ov::element::i64, {3}, {1, 0, 2}));
auto perm = ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3});

if (op_case == 1) {
res = std::make_shared<ov::op::v1::Transpose>(src, perm);
} else if (op_case == 4) {
auto output_shape = context.get_output_shape(0).to_shape();
auto n_heads = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[1]});
auto head_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]});
auto n_seq_active = context.get_input("n_seq_active");
auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});

auto new_shape =
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{n_seq_active, neg_one, n_heads, head_size}, 0);

// // Alternative
// auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
// auto new_shape = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{n_seq_active, neg_one, zero, zero}, 0);

auto reshaped = std::make_shared<ov::op::v1::Reshape>(src, new_shape, true);
res = std::make_shared<ov::op::v1::Transpose>(reshaped, perm);
} else {
auto cache_shape = src.get_partial_shape();
auto output_shape = context.get_output_shape(0).to_shape();
int64_t head_size = output_shape[3];
int64_t n_heads = output_shape[1];
int64_t ctx_per_seq = cache_shape[2].get_length();
int64_t n_seq = cache_shape[1].get_length();

Output<Node> attention_size;
if (context.is_static()) {
attention_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {INT_MAX});
} else if (op_case == 2) {
attention_size = context.get_input("attention_size");
} else {
attention_size = context.get_input("attention_size_swa");
}

// 1. reshape to [n_seq, ctx_per_seq, n_heads, head_size]
// 2. slice out the active sequences
// 3. slice out the attention part in each sequence
// 4. permute
auto seq_active_start = context.get_input("seq_active_start");
auto seq_active_end = context.get_input("seq_active_end");

auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});

auto src_reshaped = std::make_shared<ov::op::v1::Reshape>(
src, ov::op::v0::Constant::create(ov::element::i64, {4}, {n_seq, ctx_per_seq, n_heads, head_size}), false);
auto slice1 = std::make_shared<ov::op::v8::Slice>(src_reshaped, seq_active_start, seq_active_end, one, zero);
auto slice2 = std::make_shared<ov::op::v8::Slice>(slice1, zero, attention_size, one, one);
res = std::make_shared<ov::op::v1::Transpose>(slice2, perm);
}
return rename_outputs_with_suffix({res}, context.get_name());
}
@@ -7,8 +7,10 @@
#include <openvino/core/node.hpp>
#include <openvino/core/node_output.hpp>
#include <openvino/frontend/exception.hpp>
#include <openvino/op/concat.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/reshape.hpp>
#include <stdexcept>
#include <vector>

namespace ov {

@@ -23,22 +25,43 @@ OutputVector translate_reshape(const NodeContext & context) {
}

int op_case = context.get_op_case();
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3 || op_case == 4,
"Unsupported RESHAPE case");
FRONT_END_CHECK_IMPLEMENTED(
op_case == 1 || op_case == 2 || op_case == 3 || op_case == 4 || op_case == 5 || op_case == 6,
"Unsupported RESHAPE case");

auto output_shape = context.get_output_shape(0).to_shape();
std::shared_ptr<ov::Node> new_shape_node;
if (op_case == 1) {
new_shape_node = ov::op::v0::Constant::create(
ov::element::i64, {3}, std::vector<int64_t>{-1, (int64_t) output_shape[1], (int64_t) output_shape[2]});
ov::element::i64, {4},
std::vector<int64_t>{(int64_t) output_shape[0], -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});

} else if (op_case == 2) {
new_shape_node = ov::op::v0::Constant::create(
ov::element::i64, {3}, std::vector<int64_t>{(int64_t) output_shape[0], -1, (int64_t) output_shape[2]});
ov::element::i64, {4},
std::vector<int64_t>{(int64_t) output_shape[0], (int64_t) output_shape[1], -1, (int64_t) output_shape[3]});

} else if (op_case == 3) {
new_shape_node =
ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>{(int64_t) output_shape[0], -1, 1});
throw std::runtime_error("might be outdated RESHAPE case");
new_shape_node = ov::op::v0::Constant::create(
ov::element::i64, {4}, std::vector<int64_t>{(int64_t) output_shape[0], (int64_t) output_shape[1], -1, 1});

} else if (op_case == 4) {
return {context.get_input(0).get_node_shared_ptr()->input_value(0)};

} else if (op_case == 5) {
std::vector<int64_t> shape_vec = {1, 1, -1, (int64_t) context.get_output_shape(0).to_shape()[3]};
new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, shape_vec);

// // Alternative
// auto token_len = context.get_input("token_len");
// auto emb_size =
// ov::op::v0::Constant::create(ov::element::i64, {1}, {(int64_t) context.get_output_shape(0).to_shape()[3]});
// auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
// new_shape_node = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one, one, token_len, emb_size}, 0);

} else if (op_case == 6) {
new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, context.get_output_shape(0).to_shape());
}
auto res = std::make_shared<ov::op::v1::Reshape>(context.get_input(0), new_shape_node, false);
return rename_outputs_with_suffix({res}, context.get_name());
@@ -52,10 +52,10 @@ OutputVector translate_rope(const NodeContext & context) {
if (op_case == 2) {
// The input comes from a VIEW
int slice_len = output_shape[1] * output_shape[2];
int slice_len = output_shape[2] * output_shape[3];
data_node = process_view_input(context, 0, slice_len).get_node_shared_ptr();
auto data_shape = ov::op::v0::Constant::create(
ov::element::i64, {3}, std::vector<int64_t>{-1, (int64_t) output_shape[1], (int64_t) output_shape[2]});
ov::element::i64, {4}, std::vector<int64_t>{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
data_node = std::make_shared<ov::op::v1::Reshape>(data_node, data_shape, false);
}

@@ -67,9 +67,10 @@ OutputVector translate_rope(const NodeContext & context) {
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[2]});
auto even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, two);
auto odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, two);
auto three = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]});
auto even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, three);
auto odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, three);

Output<Node> first_half =
std::make_shared<ov::op::v1::Subtract>(std::make_shared<ov::op::v1::Multiply>(even_slice, cos_theta_node),

@@ -79,14 +80,17 @@ OutputVector translate_rope(const NodeContext & context) {
std::make_shared<ov::op::v1::Multiply>(odd_slice, cos_theta_node));

first_half = std::make_shared<ov::op::v0::Unsqueeze>(first_half,
ov::op::v0::Constant::create(ov::element::i64, {1}, {3}));
ov::op::v0::Constant::create(ov::element::i64, {1}, {4}));
second_half = std::make_shared<ov::op::v0::Unsqueeze>(second_half,
ov::op::v0::Constant::create(ov::element::i64, {1}, {3}));
auto stack = std::make_shared<ov::op::v0::Concat>(OutputVector{first_half, second_half}, 3);
res = std::make_shared<ov::op::v1::Reshape>(stack, std::make_shared<ov::op::v0::ShapeOf>(data_node), false);
ov::op::v0::Constant::create(ov::element::i64, {1}, {4}));
auto stack = std::make_shared<ov::op::v0::Concat>(OutputVector{first_half, second_half}, 4);

auto data_shape = ov::op::v0::Constant::create(
ov::element::i64, {4}, std::vector<int64_t>{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
res = std::make_shared<ov::op::v1::Reshape>(stack, data_shape, false);
} else if (mode == ROPE_TYPE_NEOX) {
auto data_split = std::make_shared<ov::op::v1::Split>(
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2}), 2);
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {3}), 2);
Output<Node> slice_data_node_0 = data_split->outputs()[0];
Output<Node> slice_data_node_1 = data_split->outputs()[1];

@@ -98,7 +102,7 @@ OutputVector translate_rope(const NodeContext & context) {
std::make_shared<ov::op::v1::Multiply>(slice_data_node_0, sin_theta_node),
std::make_shared<ov::op::v1::Multiply>(slice_data_node_1, cos_theta_node));

res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, 2);
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, 3);
}

return rename_outputs_with_suffix({res}, context.get_name());
@@ -28,33 +28,28 @@ OutputVector translate_set_rows(const NodeContext & context) {
num_inputs_check(context, 3, 3);

auto data = context.get_input(0);
auto indices = context.get_input(1);
auto dst = context.get_input(2);

data = std::make_shared<ov::op::v0::Convert>(data, context.get_output_type(0));

auto dst_shape = context.get_output_shape(0).to_shape();
FRONT_END_OP_CONVERSION_CHECK(dst_shape[0] == 1, "Unsupported shape in SET_ROWS");

auto indices = context.get_input(1);
auto dst = context.get_input(context.get_output_name());
auto ind_squeezed =
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 1, 2}));
auto data_reshaped = std::make_shared<ov::op::v1::Reshape>(
data,
ov::op::v0::Constant::create(ov::element::i64, {4},
{(int64_t) 1, (int64_t) 1, (int64_t) -1, (int64_t) dst_shape[3]}),
false);
auto axes = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2});

auto zero = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
Output<Node> res;
if (context.is_static()) {
auto dst_reshaped = std::make_shared<ov::op::v1::Reshape>(
dst, ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) dst_shape[1], (int64_t) dst_shape[2]}),
false);
auto indices_reshaped =
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
auto data_reshaped = std::make_shared<ov::op::v1::Reshape>(
data, ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) -1, (int64_t) dst_shape[2]}), false);
Output<Node> res = std::make_shared<ov::op::v3::ScatterUpdate>(dst, ind_squeezed, data_reshaped, axes);

auto updated = std::make_shared<ov::op::v3::ScatterUpdate>(dst_reshaped, indices_reshaped, data_reshaped, zero);
res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(dst), false);
} else {
int64_t dim1 = dst.get_partial_shape()[1].get_length();
int64_t dim2 = dst.get_partial_shape()[2].get_length();
data = std::make_shared<ov::op::v1::Reshape>(
data, ov::op::v0::Constant::create(ov::element::i64, {3}, {(int64_t) -1, dim1, dim2}), false);
res = std::make_shared<ov::op::v0::Concat>(OutputVector{dst, data}, 0);
if (auto dst_reshape = std::dynamic_pointer_cast<ov::op::v1::Reshape>(dst.get_node_shared_ptr())) {
// Fix the case of multiple sequences, reshape back to original shape [1, n_seq, ctx_per_seq, emb]
res = std::make_shared<ov::op::v1::Reshape>(
res, ov::op::v0::Constant::create(ov::element::i64, {4}, dst_reshape->get_input_shape(0)), false);
}
return rename_outputs_with_suffix({res}, context.get_name());
}
@@ -23,6 +23,7 @@ namespace ggml {
namespace op {

OutputVector translate_soft_max(const NodeContext & context) {
// TODO code is outdated
num_inputs_check(context, 1, 2);

auto input_node = context.get_input(0).get_node_shared_ptr();

@@ -12,8 +12,8 @@ namespace op {
OutputVector translate_transpose(const NodeContext & context) {
num_inputs_check(context, 1, 1);

auto res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 2, 1}));
auto res = std::make_shared<ov::op::v1::Transpose>(
context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 1, 3, 2}));
return rename_outputs_with_suffix({res}, context.get_name());
}

@@ -11,7 +11,7 @@ OutputVector translate_view(const NodeContext & context) {
if (context.get_op_case() == 2) {
auto dst_shape = context.get_output_shape(0).to_shape();
return rename_outputs_with_suffix({process_view_input(context, 0, dst_shape[1] * dst_shape[2])},
return rename_outputs_with_suffix({process_view_input(context, 0, dst_shape[2] * dst_shape[3])},
context.get_name());
}
return {context.get_input(0)};
@@ -72,15 +72,8 @@ ov::pass::MakeStateful::ParamResPairs get_kv_param_res_pairs(
return pairs;
}

void add_token_len(TensorMap & tensor_map) {
auto inp_tokens = tensor_map.at("inp_tokens").get_node_shared_ptr();
auto token_len = get_dimensions(inp_tokens, {2});
token_len->set_friendly_name("token_len");
tensor_map.insert({"token_len", token_len->output(0)});
}

void add_sliced_mask(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) {
auto token_len = tensor_map.at("token_len").get_node_shared_ptr();
auto token_len_per_seq = tensor_map.at("token_len_per_seq").get_node_shared_ptr();

auto create_sliced_mask = [&](const std::string & mask_name, const std::string & sliced_name, bool is_static) {
if (tensor_map.find(mask_name) != tensor_map.end()) {

@@ -89,28 +82,10 @@ void add_sliced_mask(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) {
if (is_static) {
mask_sliced = mask;
} else {
auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 0});
auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 1});
auto one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto two_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 2});

std::shared_ptr<ov::Node> kv_len;
{
auto start = ov::op::v0::Constant::create(element::i64, Shape{3}, {0, 0, -1});
auto stride = ov::op::v0::Constant::create(element::i64, Shape{3}, {1, 1, 1});
auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr();
kv_len = std::make_shared<ov::op::v1::StridedSlice>(
inp_pos, start, start, stride, std::vector<int64_t>{0, 0, 0}, std::vector<int64_t>{1, 1, 1});
}
kv_len = std::make_shared<ov::op::v0::Squeeze>(
kv_len, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
kv_len = std::make_shared<ov::op::v0::Convert>(kv_len, ov::element::i64);
kv_len = std::make_shared<ov::op::v1::Add>(kv_len, one_1d);
auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len, kv_len}, 0);

mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes);
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len_per_seq, one, two);
mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16);
mask_sliced->set_friendly_name(sliced_name);
}

@@ -119,8 +94,7 @@ void add_sliced_mask(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) {
};

create_sliced_mask("KQ_mask", "KQ_mask_sliced", ggml_model_decoder.is_static());
// swa is not working for the `kv_len` is not correct
// create_sliced_mask("KQ_mask_swa", "KQ_mask_swa_sliced", ggml_model_decoder.is_static());
create_sliced_mask("KQ_mask_swa", "KQ_mask_swa_sliced", ggml_model_decoder.is_static());
}

void add_rope_sin_cos(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) {

@@ -143,7 +117,6 @@ void add_rope_sin_cos(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder)
// Create common patterns
void preprocess(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) {
add_token_len(tensor_map);
add_sliced_mask(tensor_map, ggml_model_decoder);
add_rope_sin_cos(tensor_map, ggml_model_decoder);
}
@@ -77,12 +77,12 @@ ov::Output<ov::Node> rope_yarn_ramp_mix(int n_dims, const float corr_dims[2], fl
int half_n_dims = n_dims / 2;
std::vector<float> dim_ids_vec(half_n_dims);
std::iota(dim_ids_vec.begin(), dim_ids_vec.end(), 0);
auto dim_ids = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, (size_t) half_n_dims}, dim_ids_vec);
auto corr_low = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {corr_dims[0]});
auto corr_high = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {corr_dims[1]});
auto denom =
std::make_shared<ov::op::v1::Maximum>(std::make_shared<ov::op::v1::Subtract>(corr_high, corr_low),
ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {0.001f}));
auto dim_ids = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, (size_t) half_n_dims}, dim_ids_vec);
auto corr_low = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {corr_dims[0]});
auto corr_high = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {corr_dims[1]});
auto denom = std::make_shared<ov::op::v1::Maximum>(
std::make_shared<ov::op::v1::Subtract>(corr_high, corr_low),
ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {0.001f}));
auto ramp_y =
std::make_shared<ov::op::v1::Divide>(std::make_shared<ov::op::v1::Subtract>(dim_ids, corr_low), denom);
auto ramp_clamped = std::make_shared<ov::op::v0::Clamp>(ramp_y, 0.0f, 1.0f);

@@ -116,7 +116,7 @@ std::pair<ov::Output<Node>, ov::Output<Node>> make_sin_cos(int32_t * rope_params
std::shared_ptr<ov::Node> rope_freqs_weight) {
inp_pos = std::make_shared<ov::op::v0::Convert>(inp_pos, ov::element::f32);
auto pos_perm =
std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{2, 1, 0});
std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 3, 1, 2});
inp_pos = std::make_shared<ov::op::v1::Transpose>(inp_pos, pos_perm);

float freq_base;

@@ -146,7 +146,7 @@ std::pair<ov::Output<Node>, ov::Output<Node>> make_sin_cos(int32_t * rope_params
}

Output<Node> freq_factors =
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor);
std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor);
if (rope_freqs_weight) {
freq_factors = std::make_shared<ov::op::v1::Divide>(freq_factors, rope_freqs_weight);
}

@@ -161,7 +161,7 @@ std::pair<ov::Output<Node>, ov::Output<Node>> make_sin_cos(int32_t * rope_params
theta = theta_interp;
} else {
auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor);
auto one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f});
auto one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f});
auto one_minus_ramp = std::make_shared<ov::op::v1::Subtract>(one, ramp_mix);

theta = std::make_shared<ov::op::v1::Add>(std::make_shared<ov::op::v1::Multiply>(theta_interp, one_minus_ramp),

@@ -183,19 +183,19 @@ ov::Output<ov::Node> process_view_input(const NodeContext & context, int input_i
// Only works for VIEW operations that slice at the lowest dimension
// If the VIEW also reshape the result, `slice_len` should be provided
auto input = context.get_input(input_index);
int32_t * op_params = context.get_input_op_params(input_index);
auto * op_params = (size_t *) context.get_input_op_params(input_index);
auto src1_stride = context.get_input_stride(input_index);

int64_t split_addr = op_params[0] / src1_stride[2];
int64_t split_addr = op_params[0] / src1_stride[3];
if (slice_len == 0) {
slice_len = context.get_input_shape(input_index)[2].get_length();
slice_len = context.get_input_shape(input_index)[3].get_length();
}
int64_t slice_end = split_addr + slice_len;

auto begin = ov::op::v0::Constant::create(ov::element::i64, {1}, {split_addr});
auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {slice_end});
auto stride = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
auto sliced = std::make_shared<ov::op::v8::Slice>(input, begin, end, stride, axes);
return sliced;
}
@@ -129,18 +129,24 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, ggml_cgraph *
ov_input_names_cache[cgraph] = ov_input_names;
ov_output_names_cache[cgraph] = ov_output_names;

// Set output tensors and kvcache address for NPU once and for all since the graph is static
if (is_static) {
for (size_t i = 0; i < ov_output_names.size(); i++) {
// Set output tensors (for NPU) and kvcache i/o tensors once and for all
for (size_t i = 0; i < ov_output_names.size(); i++) {
auto output_name = ov_output_names[i];
if (is_static || output_name.find("cache") == 0) {
auto output_tensor = get_ov_output_tensor(ggml_decoder, ov_output_names[i]);
infer_request->set_output_tensor(i, output_tensor);
}
for (size_t i = 0; i < ov_input_names.size(); i++) {
auto param_name = ov_input_names[i];
if (param_name.find("cache_k") == 0 || param_name.find("cache_v") == 0) {
auto input_tensor = get_ov_input_tensor_static(ggml_decoder, param_name, 0, 0);
infer_request->set_input_tensor(i, input_tensor);
}
for (size_t i = 0; i < ov_input_names.size(); i++) {
auto param_name = ov_input_names[i];
if (param_name.find("cache") == 0) {
ov::Tensor input_tensor;
if (is_static) {
input_tensor = get_ov_input_tensor_static(ggml_decoder, param_name, 0, 0);
} else {
input_tensor = get_ov_input_tensor(ggml_decoder, param_name);
}
infer_request->set_input_tensor(i, input_tensor);
}
}
}

@@ -152,6 +158,9 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, ggml_cgraph *
if (!is_static) {
for (size_t i = 0; i < ov_input_names.size(); i++) {
auto param_name = ov_input_names[i];
if (param_name.find("cache") == 0) {
continue;
}
auto input_tensor = get_ov_input_tensor(ggml_decoder, param_name);
infer_request->set_input_tensor(i, input_tensor);

@@ -179,7 +188,7 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, ggml_cgraph *
for (int j = 0; j < input_len; j++) {
for (size_t i = 0; i < ov_input_names.size(); i++) {
auto param_name = ov_input_names[i];
if (param_name.find("cache_k") == 0 || param_name.find("cache_v") == 0) {
if (param_name.find("cache") == 0) {
continue;
}
auto input_tensor = get_ov_input_tensor_static(ggml_decoder, param_name, j, input_len);

@@ -306,7 +315,7 @@ ov::Tensor convert_ggml_input_to_ov(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
ov::Shape input_shape;
if (ggml_tensor->op == GGML_OP_VIEW) {
// This case is added to make test-backend-ops work
input_shape = ggml_decoder->get_graph_input_shape(ggml_tensor->view_src).to_shape();
input_shape = ggml_decoder->get_graph_input_shape(ggml_tensor, ggml_tensor->view_src).to_shape();
} else {
input_shape = ggml_decoder->get_input_shape(name).to_shape();
}

@@ -319,13 +328,6 @@ ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, cons
ov::Tensor input_tensor;
if (ggml_decoder->get_model_extra_inputs().find(param_name) != ggml_decoder->get_model_extra_inputs().end()) {
input_tensor = *ggml_decoder->get_model_extra_input_values().at(param_name);

} else if (param_name.find("cache_k") == 0 || param_name.find("cache_v") == 0) {
void * input_data = ggml_decoder->get_input_ggml_tensor(param_name)->data;
ov::Shape input_shape = {(size_t) ggml_decoder->get_past_kv_len(), (size_t) ggml_decoder->get_num_heads_kv(),
(size_t) ggml_decoder->get_head_size()};
input_tensor = ov::Tensor(ggml_decoder->get_input_type(param_name), input_shape, input_data);

} else {
input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name);
}

@@ -339,15 +341,8 @@ ov::Tensor get_ov_input_tensor_static(std::shared_ptr<GgmlOvDecoder> ggml_decode
const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name);
const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor);

if (param_name.find("cache_k") == 0 || param_name.find("cache_v") == 0) {
void * input_data = ggml_decoder->get_input_ggml_tensor(param_name)->data;
ov::Shape input_shape = {(size_t) ggml_decoder->get_context_size(), (size_t) ggml_decoder->get_num_heads_kv(),
(size_t) ggml_decoder->get_head_size()};
return ov::Tensor(ggml_decoder->get_input_type(param_name), input_shape, input_data);
}

if (param_name == "inp_pos" || param_name == "inp_tokens" || op->op == GGML_OP_SET_ROWS) {
ov::Shape input_shape = {1, 1, 1};
ov::Shape input_shape = {1, 1, 1, 1};
ov::Tensor input_tensor(ggml_decoder->get_input_type(param_name), input_shape);
// copy the j-th value from ggml_tensor
size_t element_size = ggml_type_size(ggml_tensor->type);

@@ -357,7 +352,7 @@ ov::Tensor get_ov_input_tensor_static(std::shared_ptr<GgmlOvDecoder> ggml_decode
}

if (param_name == "inp_out_ids") {
ov::Shape input_shape = {1, 1, 1};
ov::Shape input_shape = {1, 1, 1, 1};
ov::Tensor input_tensor(ggml_decoder->get_input_type(param_name), input_shape);
if (ggml_tensor->ne[0] == 0) {
*input_tensor.data<int32_t>() = 0;

@@ -374,10 +369,10 @@ ov::Tensor get_ov_input_tensor_static(std::shared_ptr<GgmlOvDecoder> ggml_decode
}

if (param_name.find("KQ_mask") == 0) {
size_t context_size = ggml_decoder->get_context_size();
size_t context_size = ggml_decoder->get_ctx_size();
const auto * input_tensor_ggml = ggml_decoder->get_input_ggml_tensor(param_name);
std::vector<float> padded_data = pad_input<float>(input_tensor_ggml, input_len, context_size, -INFINITY);
ov::Tensor input_tensor(ov::element::f32, ov::Shape{1, 1, context_size});
ov::Tensor input_tensor(ov::element::f32, ov::Shape{1, 1, 1, context_size});
// copy the j-th row of padded_data
auto * data_ptr = input_tensor.data<float>();
std::copy(padded_data.begin() + j * context_size, padded_data.begin() + (j + 1) * context_size, data_ptr);

@@ -391,20 +386,12 @@ ov::Tensor get_ov_output_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, con
auto * ggml_tensor = ggml_decoder->get_output_ggml_tensor(result_name);
auto output_type = ggml_decoder->get_output_type(result_name);
ov::Shape output_shape;
if (result_name.find("cache") == std::string::npos) {
output_shape = ggml_decoder->get_output_shape(result_name).to_shape();
if (ggml_decoder->is_static() && result_name == "result_output") {
output_shape[1] = 1;
}
} else {
size_t total_token_len = ggml_decoder->get_past_kv_len() + ggml_decoder->get_input_len();
size_t num_heads_kv = ggml_decoder->get_num_heads_kv();
size_t head_size = ggml_decoder->get_head_size();
if (ggml_decoder->is_static()) {
total_token_len = ggml_decoder->get_context_size();
}
output_shape = ov::Shape{total_token_len, num_heads_kv, head_size};
output_shape = ggml_decoder->get_output_shape(result_name).to_shape();

if (ggml_decoder->is_static() && result_name == "result_output") {
output_shape[1] = 1;
}

ov::Tensor output_tensor(output_type, output_shape, ggml_tensor->data);
return output_tensor;
}