FIX: set_max_token_len
This commit is contained in:
parent
a30dc6e726
commit
8ac5c225aa
|
|
@ -44,13 +44,14 @@ GgmlOvDecoder::GgmlOvDecoder(struct ggml_tensor* node, struct ggml_cgraph* cgrap
|
|||
dump_cgraph(m_cgraph);
|
||||
}
|
||||
|
||||
set_max_token_len();
|
||||
|
||||
static bool weight_created = false;
|
||||
if (!getenv("GGML_OPENVINO_WEIGHT_AS_INPUT") && !weight_created) {
|
||||
add_weight_const_parallel(model_weights);
|
||||
weight_created = true;
|
||||
}
|
||||
|
||||
set_max_token_len();
|
||||
for (int node_n = 0; node_n < m_cgraph->n_nodes; node_n++) {
|
||||
auto* cur_node = m_cgraph->nodes[node_n];
|
||||
m_nodes.push_back(cur_node);
|
||||
|
|
@ -197,7 +198,7 @@ void GgmlOvDecoder::set_max_token_len() {
|
|||
auto* node = m_cgraph->nodes[i];
|
||||
if (std::string(node->name) == "k-0") {
|
||||
auto* cache_k = node->src[0];
|
||||
m_max_token_len = cache_k->ne[0] / node->ne[0] / node->ne[1];
|
||||
m_max_token_len = cache_k->ne[0] / node->ne[0] / node->ne[2];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -209,4 +209,4 @@ void print_output_tensor_info(const std::string& name,
|
|||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue