Change due to ggml cgraph changes, not correct yet

This commit is contained in:
Yu, Zijun 2025-06-04 17:22:50 +08:00 committed by Mustafa Cavus
parent d9ca8f5dbe
commit f7ad77930e
3 changed files with 27 additions and 9 deletions

View File

@ -187,6 +187,16 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node) {
case GGML_OP_MUL_MAT: {
if (node->src[0]->view_src == nullptr) {
m_op_case = 1;
} else if (std::string(node->src[0]->name).find("cache_k") == 0) {
m_op_case = 2;
} else if (std::string(node->src[0]->name).find("cache_v") == 0) {
m_op_case = 3;
}
break;
}
case GGML_OP_PERMUTE: {
if (ggml_is_contiguous(node->src[0])) {
m_op_case = 1;
} else {
m_op_case = 2;
}

View File

@ -24,7 +24,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
num_inputs_check(context, 2, 2);
int op_case = context.get_op_case();
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2, "Unsupported MULMAT case");
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported MULMAT case");
ov::Output<Node> res;
@ -59,8 +59,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
auto src0 = context.get_input(0);
auto src0_shape = context.get_input_shape(0).to_shape();
auto src0_stride = context.get_input_stride(0);
auto permuted = is_permuted(src0_stride);
auto token_dim = permuted ? 0 : 2;
auto token_dim = op_case == 2 ? 0 : 2;
auto attention_size = context.get_input("attention_size");
@ -81,7 +80,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
auto src0_reshape = std::make_shared<ov::op::v1::Reshape>(src0, src0_reshape_shape, false);
std::shared_ptr<ov::Node> slice_end;
if (permuted) {
if (op_case == 2) {
slice_end = std::make_shared<ov::op::v0::Concat>(
ov::OutputVector{attention_size, ov::op::v0::Constant::create(ov::element::i64, {2}, src0_slice_shape)},
0);
@ -94,7 +93,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
auto slice_step = ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>(3, 1));
auto src0_slice = std::make_shared<ov::op::v8::Slice>(src0_reshape, slice_start, slice_end, slice_step);
if (permuted) {
if (op_case == 2) {
B = std::make_shared<ov::op::v1::Transpose>(
src0_slice,
ov::op::v0::Constant::create(ov::element::i64, {src0_perm.size()}, src0_perm));

View File

@ -12,10 +12,19 @@ namespace op {
OutputVector translate_permute(const NodeContext& context) {
num_inputs_check(context, 1, 1);
auto perm = argsort_descend(context.get_output_stride(0));
auto res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
ov::op::v0::Constant::create(ov::element::i64, {3}, perm));
return rename_outputs_with_suffix({res}, context.get_name());
int op_case = context.get_op_case();
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2, "Unsupported CONT case");
ov::Output<Node> res;
if (op_case == 1) {
auto perm = argsort_descend(context.get_output_stride(0));
auto res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
ov::op::v0::Constant::create(ov::element::i64, {3}, perm));
return rename_outputs_with_suffix({res}, context.get_name());
} else {
auto res = context.get_input(0);
return {res};
}
}
} // namespace op