Fix after rebasing
This commit is contained in:
parent
f3afa7b914
commit
fdadca1e89
|
|
@ -198,13 +198,17 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
|
|||
if (node->src[0]->op != GGML_OP_VIEW) {
|
||||
m_op_case = 1;
|
||||
} else if (ggml_is_contiguous(node->src[0])) {
|
||||
// Permute kv cache (view)
|
||||
std::string src_name(node->view_src->name);
|
||||
int layer = extract_layer_from_name(src_name);
|
||||
if (!is_swa_layer(layer)) {
|
||||
m_op_case = 2;
|
||||
if (src_name.find("cache") == std::string::npos) {
|
||||
m_op_case = 1;
|
||||
} else {
|
||||
m_op_case = 3;
|
||||
// Permute kv cache (view)
|
||||
int layer = extract_layer_from_name(src_name);
|
||||
if (!is_swa_layer(layer)) {
|
||||
m_op_case = 2;
|
||||
} else {
|
||||
m_op_case = 3;
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
|
@ -230,6 +234,16 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) {
|
|||
}
|
||||
break;
|
||||
}
|
||||
case GGML_OP_VIEW: {
|
||||
if (node->src[0]->op == GGML_OP_VIEW) {
|
||||
auto* src = node->src[0];
|
||||
auto* view_src = src->view_src;
|
||||
if (view_src->ne[1] != src->ne[2]) {
|
||||
throw std::runtime_error("Unsupported VIEW case");
|
||||
}
|
||||
m_op_case = 2;
|
||||
}
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -45,7 +45,9 @@ OutputVector translate_set_rows(const NodeContext& context) {
|
|||
false);
|
||||
auto indices_reshaped =
|
||||
std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
|
||||
auto data_reshaped = std::make_shared<ov::op::v0::Squeeze>(data, zero);
|
||||
auto data_reshaped = std::make_shared<ov::op::v1::Reshape>(
|
||||
data, ov::op::v0::Constant::create(ov::element::i64, {2}, {(int64_t) -1, (int64_t) dst_shape[2]}), false);
|
||||
|
||||
auto updated = std::make_shared<ov::op::v3::ScatterUpdate>(dst_reshaped, indices_reshaped, data_reshaped, zero);
|
||||
auto res = std::make_shared<ov::op::v1::Reshape>(updated, std::make_shared<ov::op::v0::ShapeOf>(dst), false);
|
||||
return rename_outputs_with_suffix({res}, context.get_name());
|
||||
|
|
|
|||
|
|
@ -9,6 +9,10 @@ namespace op {
|
|||
OutputVector translate_view(const NodeContext& context) {
|
||||
num_inputs_check(context, 1, 1);
|
||||
|
||||
if (context.get_op_case() == 2) {
|
||||
auto dst_shape = context.get_output_shape(0).to_shape();
|
||||
return rename_outputs_with_suffix({process_view_input(context, 0, dst_shape[1] * dst_shape[2])}, context.get_name());
|
||||
}
|
||||
return {context.get_input(0)};
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue