Change due to ggml cgraph changes, not correct yet
This commit is contained in:
parent
d9ca8f5dbe
commit
f7ad77930e
|
|
@ -187,6 +187,16 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node) {
|
|||
case GGML_OP_MUL_MAT: {
|
||||
if (node->src[0]->view_src == nullptr) {
|
||||
m_op_case = 1;
|
||||
} else if (std::string(node->src[0]->name).find("cache_k") == 0) {
|
||||
m_op_case = 2;
|
||||
} else if (std::string(node->src[0]->name).find("cache_v") == 0) {
|
||||
m_op_case = 3;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case GGML_OP_PERMUTE: {
|
||||
if (ggml_is_contiguous(node->src[0])) {
|
||||
m_op_case = 1;
|
||||
} else {
|
||||
m_op_case = 2;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
|
|||
num_inputs_check(context, 2, 2);
|
||||
|
||||
int op_case = context.get_op_case();
|
||||
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2, "Unsupported MULMAT case");
|
||||
FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported MULMAT case");
|
||||
|
||||
ov::Output<Node> res;
|
||||
|
||||
|
|
@ -59,8 +59,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
|
|||
auto src0 = context.get_input(0);
|
||||
auto src0_shape = context.get_input_shape(0).to_shape();
|
||||
auto src0_stride = context.get_input_stride(0);
|
||||
auto permuted = is_permuted(src0_stride);
|
||||
auto token_dim = permuted ? 0 : 2;
|
||||
auto token_dim = op_case == 2 ? 0 : 2;
|
||||
|
||||
auto attention_size = context.get_input("attention_size");
|
||||
|
||||
|
|
@ -81,7 +80,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
|
|||
auto src0_reshape = std::make_shared<ov::op::v1::Reshape>(src0, src0_reshape_shape, false);
|
||||
|
||||
std::shared_ptr<ov::Node> slice_end;
|
||||
if (permuted) {
|
||||
if (op_case == 2) {
|
||||
slice_end = std::make_shared<ov::op::v0::Concat>(
|
||||
ov::OutputVector{attention_size, ov::op::v0::Constant::create(ov::element::i64, {2}, src0_slice_shape)},
|
||||
0);
|
||||
|
|
@ -94,7 +93,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
|
|||
auto slice_step = ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector<int64_t>(3, 1));
|
||||
auto src0_slice = std::make_shared<ov::op::v8::Slice>(src0_reshape, slice_start, slice_end, slice_step);
|
||||
|
||||
if (permuted) {
|
||||
if (op_case == 2) {
|
||||
B = std::make_shared<ov::op::v1::Transpose>(
|
||||
src0_slice,
|
||||
ov::op::v0::Constant::create(ov::element::i64, {src0_perm.size()}, src0_perm));
|
||||
|
|
|
|||
|
|
@ -12,10 +12,19 @@ namespace op {
|
|||
// Translates a PERMUTE node into OpenVINO operations.
//
// The behavior is selected by the op case read from the node context:
//   1 - emit a Transpose of input 0 whose order is argsort_descend() of the
//       output strides, and rename the outputs with the node name;
//   2 - emit no new operation and forward input 0 unchanged.
// Any other op case is rejected by FRONT_END_CHECK_IMPLEMENTED.
OutputVector translate_permute(const NodeContext& context) {
    num_inputs_check(context, 1, 1);

    const int op_case = context.get_op_case();
    FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2, "Unsupported PERMUTE case");

    if (op_case == 1) {
        auto perm = argsort_descend(context.get_output_stride(0));
        // NOTE(review): the order constant is declared with shape {3}, so this
        // presumably expects exactly three strides - confirm against callers.
        auto res = std::make_shared<ov::op::v1::Transpose>(context.get_input(0),
                                                           ov::op::v0::Constant::create(ov::element::i64, {3}, perm));
        return rename_outputs_with_suffix({res}, context.get_name());
    }

    // op_case == 2: pass the input through as-is. Unlike case 1, the output is
    // returned without rename_outputs_with_suffix().
    auto input = context.get_input(0);
    return {input};
}
|
||||
|
||||
} // namespace op
|
||||
|
|
|
|||
Loading…
Reference in New Issue