diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 66f82773e3..2a95c894f4 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -187,6 +187,16 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node) { case GGML_OP_MUL_MAT: { if (node->src[0]->view_src == nullptr) { m_op_case = 1; + } else if (std::string(node->src[0]->name).find("cache_k") == 0) { + m_op_case = 2; + } else if (std::string(node->src[0]->name).find("cache_v") == 0) { + m_op_case = 3; + } + break; + } + case GGML_OP_PERMUTE: { + if (ggml_is_contiguous(node->src[0])) { + m_op_case = 1; } else { m_op_case = 2; } diff --git a/ggml/src/ggml-openvino/openvino/op/mulmat.cpp b/ggml/src/ggml-openvino/openvino/op/mulmat.cpp index 0d3190f6c1..728ee5cb5f 100644 --- a/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +++ b/ggml/src/ggml-openvino/openvino/op/mulmat.cpp @@ -24,7 +24,7 @@ OutputVector translate_mulmat(const NodeContext& context) { num_inputs_check(context, 2, 2); int op_case = context.get_op_case(); - FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2, "Unsupported MULMAT case"); + FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported MULMAT case"); ov::Output res; @@ -59,8 +59,7 @@ OutputVector translate_mulmat(const NodeContext& context) { auto src0 = context.get_input(0); auto src0_shape = context.get_input_shape(0).to_shape(); auto src0_stride = context.get_input_stride(0); - auto permuted = is_permuted(src0_stride); - auto token_dim = permuted ? 0 : 2; + auto token_dim = op_case == 2 ? 0 : 2; auto attention_size = context.get_input("attention_size"); @@ -81,7 +80,7 @@ OutputVector translate_mulmat(const NodeContext& context) { auto src0_reshape = std::make_shared(src0, src0_reshape_shape, false); std::shared_ptr slice_end; - if (permuted) { + if (op_case == 2) { slice_end = std::make_shared( ov::OutputVector{attention_size, ov::op::v0::Constant::create(ov::element::i64, {2}, src0_slice_shape)}, 0); @@ -94,7 +93,7 @@ OutputVector translate_mulmat(const NodeContext& context) { auto slice_step = ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector(3, 1)); auto src0_slice = std::make_shared(src0_reshape, slice_start, slice_end, slice_step); - if (permuted) { + if (op_case == 2) { B = std::make_shared( src0_slice, ov::op::v0::Constant::create(ov::element::i64, {src0_perm.size()}, src0_perm)); diff --git a/ggml/src/ggml-openvino/openvino/op/permute.cpp b/ggml/src/ggml-openvino/openvino/op/permute.cpp index 649cf8f3e1..8e91b61201 100644 --- a/ggml/src/ggml-openvino/openvino/op/permute.cpp +++ b/ggml/src/ggml-openvino/openvino/op/permute.cpp @@ -12,10 +12,19 @@ namespace op { OutputVector translate_permute(const NodeContext& context) { num_inputs_check(context, 1, 1); - auto perm = argsort_descend(context.get_output_stride(0)); - auto res = std::make_shared(context.get_input(0), - ov::op::v0::Constant::create(ov::element::i64, {3}, perm)); - return rename_outputs_with_suffix({res}, context.get_name()); + int op_case = context.get_op_case(); + FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2, "Unsupported CONT case"); + ov::Output res; + + if (op_case == 1) { + auto perm = argsort_descend(context.get_output_stride(0)); + auto res = std::make_shared(context.get_input(0), + ov::op::v0::Constant::create(ov::element::i64, {3}, perm)); + return rename_outputs_with_suffix({res}, context.get_name()); + } else { + auto res = context.get_input(0); + return {res}; + } } } // namespace op