Fix after rebasing

This commit is contained in:
Yu, Zijun 2025-09-28 11:24:13 +08:00 committed by Mustafa Cavus
parent f3afa7b914
commit fdadca1e89
3 changed files with 26 additions and 6 deletions

View File

@ -198,13 +198,17 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
if (node->src[0]->op != GGML_OP_VIEW) {
m_op_case = 1;
} else if (ggml_is_contiguous(node->src[0])) {
// Permute kv cache (view)
std::string src_name(node->view_src->name);
int layer = extract_layer_from_name(src_name);
if (!is_swa_layer(layer)) {
m_op_case = 2;
if (src_name.find("cache") == std::string::npos) {
m_op_case = 1;
} else {
m_op_case = 3;
// Permute kv cache (view)
int layer = extract_layer_from_name(src_name);
if (!is_swa_layer(layer)) {
m_op_case = 2;
} else {
m_op_case = 3;
}
}
}
break;
@ -230,6 +234,16 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
}
break;
}
case GGML_OP_VIEW: {
if (node->src[0]->op == GGML_OP_VIEW) {
auto* src = node->src[0];
auto* view_src = src->view_src;
if (view_src->ne[1] != src->ne[2]) {
throw std::runtime_error("Unsupported VIEW case");
}
m_op_case = 2;
}
}
default:
break;
}

View File

@ -45,7 +45,9 @@ OutputVector translate_set_rows(const NodeContext& context) {
false);
auto indices_reshaped =
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
auto data_reshaped = std::make_shared<ov::op::v0::Squeeze>(data, zero);
auto data_reshaped = std::make_shared<ov::op::v1::Reshape>(
data, ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) -1, (int64_t) dst_shape[2]}), false);
auto updated = std::make_shared<ov::op::v3::ScatterUpdate>(dst_reshaped, indices_reshaped, data_reshaped, zero);
auto res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(dst), false);
return rename_outputs_with_suffix({res}, context.get_name());

View File

@ -9,6 +9,10 @@ namespace op {
OutputVector translate_view(const NodeContext& context) {
num_inputs_check(context, 1, 1);
if (context.get_op_case() == 2) {
auto dst_shape = context.get_output_shape(0).to_shape();
return rename_outputs_with_suffix({process_view_input(context, 0, dst_shape[1] * dst_shape[2])}, context.get_name());
}
return {context.get_input(0)};
}