FIX: input shape of KQ_mask
This commit is contained in:
parent
041d220dfa
commit
c57f61494a
|
|
@ -112,8 +112,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node,
|
|||
if (std::string(src->name) == "inp_tokens" || std::string(src->name) == "inp_pos") {
|
||||
input_shape = ov::PartialShape{1, 1, ov::Dimension(1, m_max_token_len)};
|
||||
} else if (std::string(src->name).find("KQ_mask") == 0) {
|
||||
input_shape =
|
||||
ov::PartialShape{1, ov::Dimension(1, m_max_token_len), ov::Dimension(1, m_max_token_len)};
|
||||
auto max_token_len = GGML_PAD(m_max_token_len, GGML_KQ_MASK_PAD);
|
||||
input_shape = ov::PartialShape{1, ov::Dimension(1, max_token_len), ov::Dimension(1, max_token_len)};
|
||||
} else {
|
||||
input_shape = ov::Shape{get_shape(src)};
|
||||
}
|
||||
|
|
@ -187,9 +187,9 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node,
|
|||
void GgmlOvDecoder::set_max_token_len() {
|
||||
for (int i = 0; i < m_cgraph->n_nodes; i++) {
|
||||
auto* node = m_cgraph->nodes[i];
|
||||
if (std::string(node->name) == "v-0") {
|
||||
auto* cache_v = node->src[0];
|
||||
m_max_token_len = cache_v->ne[0] / node->ne[1] / node->ne[2];
|
||||
if (std::string(node->name) == "k-0") {
|
||||
auto* cache_k = node->src[0];
|
||||
m_max_token_len = cache_k->ne[0] / node->ne[0] / node->ne[1];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue