Fix llama-cli
This commit is contained in:
parent
ea75772e48
commit
1ed49bbfaf
|
|
@ -244,22 +244,36 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
|
|||
}
|
||||
|
||||
void GgmlOvDecoder::add_extra_inputs() {
|
||||
// attention_size not used for NPU
|
||||
// Extra inputs:
|
||||
// 1. `attention_size`, used in matmul's in the attention block. The shape of those matmul's are 32 aligned,
|
||||
// see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding.
|
||||
// Not used for NPU
|
||||
int64_t attention_size = -1;
|
||||
|
||||
int64_t past_token_len = -1;
|
||||
int64_t past_token_len_from_inp_pos = -1;
|
||||
for (const auto& node : m_nodes) {
|
||||
if (node->op == GGML_OP_ROPE && std::string(node->src[1]->name) == "inp_pos") {
|
||||
if (node->src[1]->type != GGML_TYPE_I32) {
|
||||
throw std::runtime_error("Expected cgraph input `inp_pos` to be of type GGML_TYPE_I32");
|
||||
}
|
||||
past_token_len_from_inp_pos = ((int32_t*) (node->src[1]->data))[0];
|
||||
}
|
||||
if (node->op == GGML_OP_CPY && ggml_is_contiguous(node)) {
|
||||
assert(std::string(node->view_src->name).find("cache_k") == 0);
|
||||
int64_t head_size = node->src[0]->ne[0];
|
||||
int64_t num_heads = node->src[0]->ne[1];
|
||||
past_token_len = (int64_t) (node->src[1]->op_params[0] / node->src[1]->nb[0] / head_size / num_heads);
|
||||
past_token_len =
|
||||
(int64_t) (node->src[1]->op_params[0] / node->src[1]->nb[0] / m_head_size / m_num_heads_kv);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (past_token_len == -1) {
|
||||
throw std::runtime_error("Failed to find input \"cache_k\" in the graph");
|
||||
}
|
||||
if (past_token_len != past_token_len_from_inp_pos) {
|
||||
throw std::runtime_error("Mismatch between past_token_len from cache_k and inp_pos: " +
|
||||
std::to_string(past_token_len) + " vs " + std::to_string(past_token_len_from_inp_pos));
|
||||
}
|
||||
|
||||
for (const auto& node : m_nodes) {
|
||||
if (node->src[1] && std::string(node->src[1]->name).find("inp_tokens") == 0) {
|
||||
int64_t total_token_len = node->src[1]->ne[0] + past_token_len;
|
||||
|
|
|
|||
Loading…
Reference in New Issue