FIX: input shape of KQ_mask

Author: Yu, Zijun (2025-05-14 17:48:20 +08:00), committed by Mustafa Cavus
parent 041d220dfa
commit c57f61494a
1 changed file with 5 additions and 5 deletions

@@ -112,8 +112,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node,
     if (std::string(src->name) == "inp_tokens" || std::string(src->name) == "inp_pos") {
         input_shape = ov::PartialShape{1, 1, ov::Dimension(1, m_max_token_len)};
     } else if (std::string(src->name).find("KQ_mask") == 0) {
-        input_shape =
-            ov::PartialShape{1, ov::Dimension(1, m_max_token_len), ov::Dimension(1, m_max_token_len)};
+        auto max_token_len = GGML_PAD(m_max_token_len, GGML_KQ_MASK_PAD);
+        input_shape = ov::PartialShape{1, ov::Dimension(1, max_token_len), ov::Dimension(1, max_token_len)};
     } else {
         input_shape = ov::Shape{get_shape(src)};
     }
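
Note on the first hunk: GGML_PAD(x, n) rounds x up to the next multiple of n, and llama.cpp allocates the KQ_mask tensor with its token dimension padded to a multiple of GGML_KQ_MASK_PAD. The incoming mask can therefore be larger than m_max_token_len, which is why the unpadded upper bound was too tight. A minimal sketch of the padding arithmetic, using the GGML_PAD definition from ggml.h and a placeholder pad value (the real GGML_KQ_MASK_PAD comes from llama.cpp):

    #include <cstdio>

    // GGML_PAD as defined in ggml.h: round x up to the next multiple of n
    // (n must be a power of two for the bit trick to work).
    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

    int main() {
        const int m_max_token_len = 2000; // assumed context capacity, illustration only
        const int kq_mask_pad     = 32;   // placeholder for GGML_KQ_MASK_PAD

        // The fix bounds both dynamic mask dimensions by the padded length,
        // so the padded tensor produced by ggml fits the PartialShape.
        printf("%d -> %d\n", m_max_token_len, GGML_PAD(m_max_token_len, kq_mask_pad));
        // prints: 2000 -> 2016
    }
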
@@ -187,9 +187,9 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node,
 void GgmlOvDecoder::set_max_token_len() {
     for (int i = 0; i < m_cgraph->n_nodes; i++) {
         auto* node = m_cgraph->nodes[i];
-        if (std::string(node->name) == "v-0") {
-            auto* cache_v = node->src[0];
-            m_max_token_len = cache_v->ne[0] / node->ne[1] / node->ne[2];
+        if (std::string(node->name) == "k-0") {
+            auto* cache_k = node->src[0];
+            m_max_token_len = cache_k->ne[0] / node->ne[0] / node->ne[1];
             break;
         }
     }
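
Note on the second hunk: the context-capacity probe moves from the V-cache view ("v-0") to the K-cache view ("k-0"), and the divisor indices change to match that view's layout. The apparent reasoning: the flat K-cache buffer holds max_token_len tokens, each occupying ne[0] * ne[1] (assumed here to be head_dim * n_kv_heads) elements, so dividing the cache's element count by the view's two leading dimensions recovers the capacity. A worked sketch of that arithmetic; head_dim, n_kv_heads, and the capacity are assumptions for illustration, not values from this repository:

    #include <cassert>
    #include <cstdint>

    int main() {
        const int64_t head_dim   = 128;  // assumed node->ne[0] of the "k-0" view
        const int64_t n_kv_heads = 8;    // assumed node->ne[1] of the "k-0" view
        const int64_t capacity   = 4096; // the max_token_len we expect to recover

        // The flat K-cache stores capacity * head_dim * n_kv_heads elements.
        const int64_t cache_k_ne0 = capacity * head_dim * n_kv_heads;

        // The fixed formula: cache_k->ne[0] / node->ne[0] / node->ne[1]
        const int64_t max_token_len = cache_k_ne0 / head_dim / n_kv_heads;
        assert(max_token_len == capacity);
        return 0;
    }
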