Fix llama-cli
This commit is contained in:
parent
ea75772e48
commit
1ed49bbfaf
|
|
@ -244,22 +244,36 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
|
||||||
}
|
}
|
||||||
|
|
||||||
void GgmlOvDecoder::add_extra_inputs() {
|
void GgmlOvDecoder::add_extra_inputs() {
|
||||||
// attention_size not used for NPU
|
// Extra inputs:
|
||||||
|
// 1. `attention_size`, used in the matmuls in the attention block. The shapes of those matmuls are 32-aligned,
|
||||||
|
// see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding.
|
||||||
|
// Not used for NPU
|
||||||
int64_t attention_size = -1;
|
int64_t attention_size = -1;
|
||||||
|
|
||||||
int64_t past_token_len = -1;
|
int64_t past_token_len = -1;
|
||||||
|
int64_t past_token_len_from_inp_pos = -1;
|
||||||
for (const auto& node : m_nodes) {
|
for (const auto& node : m_nodes) {
|
||||||
|
if (node->op == GGML_OP_ROPE && std::string(node->src[1]->name) == "inp_pos") {
|
||||||
|
if (node->src[1]->type != GGML_TYPE_I32) {
|
||||||
|
throw std::runtime_error("Expected cgraph input `inp_pos` to be of type GGML_TYPE_I32");
|
||||||
|
}
|
||||||
|
past_token_len_from_inp_pos = ((int32_t*) (node->src[1]->data))[0];
|
||||||
|
}
|
||||||
if (node->op == GGML_OP_CPY && ggml_is_contiguous(node)) {
|
if (node->op == GGML_OP_CPY && ggml_is_contiguous(node)) {
|
||||||
assert(std::string(node->view_src->name).find("cache_k") == 0);
|
assert(std::string(node->view_src->name).find("cache_k") == 0);
|
||||||
int64_t head_size = node->src[0]->ne[0];
|
past_token_len =
|
||||||
int64_t num_heads = node->src[0]->ne[1];
|
(int64_t) (node->src[1]->op_params[0] / node->src[1]->nb[0] / m_head_size / m_num_heads_kv);
|
||||||
past_token_len = (int64_t) (node->src[1]->op_params[0] / node->src[1]->nb[0] / head_size / num_heads);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (past_token_len == -1) {
|
if (past_token_len == -1) {
|
||||||
throw std::runtime_error("Failed to find input \"cache_k\" in the graph");
|
throw std::runtime_error("Failed to find input \"cache_k\" in the graph");
|
||||||
}
|
}
|
||||||
|
if (past_token_len != past_token_len_from_inp_pos) {
|
||||||
|
throw std::runtime_error("Mismatch between past_token_len from cache_k and inp_pos: " +
|
||||||
|
std::to_string(past_token_len) + " vs " + std::to_string(past_token_len_from_inp_pos));
|
||||||
|
}
|
||||||
|
|
||||||
for (const auto& node : m_nodes) {
|
for (const auto& node : m_nodes) {
|
||||||
if (node->src[1] && std::string(node->src[1]->name).find("inp_tokens") == 0) {
|
if (node->src[1] && std::string(node->src[1]->name).find("inp_tokens") == 0) {
|
||||||
int64_t total_token_len = node->src[1]->ne[0] + past_token_len;
|
int64_t total_token_len = node->src[1]->ne[0] + past_token_len;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue