From 839f8c66a0f69bca54c3f067a73dcb870daf70bf Mon Sep 17 00:00:00 2001 From: "Yu, Zijun" Date: Thu, 14 Aug 2025 16:00:38 +0800 Subject: [PATCH] Remove CPY --- ggml/src/ggml-openvino/ggml-decoder.cpp | 71 +++--------------- ggml/src/ggml-openvino/ggml-openvino.cpp | 19 ++++- ggml/src/ggml-openvino/openvino/op/cpy.cpp | 73 ------------------- ggml/src/ggml-openvino/openvino/op_table.cpp | 1 - ggml/src/ggml-openvino/openvino/op_table.hpp | 1 - .../openvino/translate_session.cpp | 60 --------------- 6 files changed, 25 insertions(+), 200 deletions(-) delete mode 100644 ggml/src/ggml-openvino/openvino/op/cpy.cpp diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp index 472dd157ef..38c7122f4c 100644 --- a/ggml/src/ggml-openvino/ggml-decoder.cpp +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -90,10 +90,10 @@ GgmlOvDecoder::GgmlOvDecoder(struct ggml_cgraph* cgraph) { // 3. constructing a decoder for the whole graph naively (op test case) void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) { std::string node_name; - if (node->op == GGML_OP_CPY || node->op == GGML_OP_SET_ROWS) { - // CPY updates the input tensor in place. For later ov op that uses the - // input tensor of CPY, we need to make sure they get the updated tensor - // by putting the src tensor name in the tensor_map in + if (node->op == GGML_OP_SET_ROWS) { + // SET_ROWS updates the tensor in place. 
For later ov op that uses + the view_src of SET_ROWS, we need to make sure they get the updated tensor + by putting the view_src name in the tensor_map in // /src/frontends/ggml/src/translate_session.cpp node_name = std::string(node->view_src->name); } else { @@ -183,16 +183,6 @@ void GgmlOvDecoder::set_input_output(ggml_tensor* node, bool naive) { } break; } - case GGML_OP_CPY: { - if (std::string(node->src[1]->name).find("cache_k") == 0) { - // Write K to cache_k - m_op_case = 1; - } else { - // Write V to cache_v - m_op_case = 2; - } - break; - } case GGML_OP_SET_ROWS: { if (std::string(node->name).find("cache_k") == 0) { m_op_case = 1; @@ -305,62 +295,22 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co void GgmlOvDecoder::add_extra_inputs() { // Extra inputs: - // 1. `past_token_len`, used to create indices for updating kv cache. Usually equal to inp_pos[0], except for - // llama-perplexity. - // Update: SET_ROWS replaces CPY for updating kv cache. The indices creation is not needed anymore. See: - // https://github.com/ggml-org/llama.cpp/pull/14285 - // 2. `attention_size`, used in matmul's in the attention block. The shape of those matmul's are 32 aligned, + // 1. `attention_size`, used in matmul's in the attention block. The shape of those matmul's are 32 aligned, // see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding. 
// Not used for NPU - int64_t past_token_len = -1; int64_t attention_size = -1; - - int64_t token_len = -1; - int64_t past_token_len_from_inp_pos = -1; for (const auto& node : m_nodes) { - if (node->op == GGML_OP_ROPE && std::string(node->src[1]->name) == "inp_pos") { - if (node->src[1]->type != GGML_TYPE_I32) { - throw std::runtime_error("Expected cgraph input `inp_pos` to be of type GGML_TYPE_I32"); + if (node->op == GGML_OP_SOFT_MAX) { + auto* mask = node->src[1]; + if (std::string(mask->name).find("KQ_mask") != 0) { + throw std::runtime_error("Unexpected softmax node: " + std::string(mask->name)); } - token_len = node->src[1]->ne[0]; - past_token_len_from_inp_pos = ((int32_t*) (node->src[1]->data))[0]; - } - if (node->op == GGML_OP_CPY && ggml_is_contiguous(node)) { - assert(std::string(node->view_src->name).find("cache_k") == 0); - past_token_len = - (int64_t) (node->src[1]->op_params[0] / node->src[1]->nb[0] / m_head_size / m_num_heads_kv); - break; - } - if (node->op == GGML_OP_SET_ROWS && std::string(node->name).find("cache_k") == 0) { - assert(node->src[1]->type == GGML_TYPE_I64); - past_token_len = *(int64_t*) (node->src[1]->data); + attention_size = mask->ne[0]; break; } } - if (past_token_len == -1) { - throw std::runtime_error("Failed to find input \"cache_k\" in the graph"); - } - if (past_token_len != past_token_len_from_inp_pos) { - GGML_LOG_DEBUG("Mismatch between past_token_len from cache_k and inp_pos: %ld vs %ld\n", - past_token_len, - past_token_len_from_inp_pos); - } - { - std::string name = "past_token_len"; - auto param_node = std::make_shared(ov::element::i64, ov::Shape{1}); - param_node->set_friendly_name(name); - param_node->output(0).get_tensor().set_names({name}); - m_model_extra_inputs[name] = param_node; - - auto tensor = std::make_shared(ov::element::i64, ov::Shape{1}); - *tensor->data() = past_token_len; - m_model_extra_input_values[name] = tensor; - } - { - int64_t total_token_len = token_len + past_token_len; - attention_size = 
GGML_PAD(total_token_len, 32); std::string name = "attention_size"; auto param_node = std::make_shared(ov::element::i64, ov::Shape{1}); param_node->set_friendly_name(name); @@ -663,7 +613,6 @@ const std::string& GgmlOvDecoder::get_op_type() const { {GGML_OP_ADD, "GGML_OP_ADD" }, {GGML_OP_ADD1, "GGML_OP_ADD1" }, {GGML_OP_CONT, "GGML_OP_CONT" }, - {GGML_OP_CPY, "GGML_OP_CPY" }, {GGML_OP_DIV, "GGML_OP_DIV" }, {GGML_OP_DUP, "GGML_OP_DUP" }, {GGML_OP_GET_ROWS, "GGML_OP_GET_ROWS" }, diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp index 14999ba66b..fb5451be32 100644 --- a/ggml/src/ggml-openvino/ggml-openvino.cpp +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -328,10 +328,21 @@ static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, con static const std::set supported_types{ GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_I64, GGML_TYPE_I32}; - static const std::set supported_ops{GGML_OP_NONE, GGML_OP_ADD, GGML_OP_MUL, GGML_OP_MUL_MAT, - GGML_OP_VIEW, GGML_OP_CONT, GGML_OP_CPY, GGML_OP_RESHAPE, - GGML_OP_PERMUTE, GGML_OP_TRANSPOSE, GGML_OP_GET_ROWS, GGML_OP_ROPE, - GGML_OP_RMS_NORM, GGML_OP_SCALE, GGML_OP_SOFT_MAX, GGML_OP_SET_ROWS}; + static const std::set supported_ops{GGML_OP_NONE, + GGML_OP_ADD, + GGML_OP_MUL, + GGML_OP_MUL_MAT, + GGML_OP_VIEW, + GGML_OP_CONT, + GGML_OP_RESHAPE, + GGML_OP_PERMUTE, + GGML_OP_TRANSPOSE, + GGML_OP_GET_ROWS, + GGML_OP_ROPE, + GGML_OP_RMS_NORM, + GGML_OP_SCALE, + GGML_OP_SOFT_MAX, + GGML_OP_SET_ROWS}; static const std::set supported_unary_ops{ GGML_UNARY_OP_SILU, }; diff --git a/ggml/src/ggml-openvino/openvino/op/cpy.cpp b/ggml/src/ggml-openvino/openvino/op/cpy.cpp deleted file mode 100644 index 553f3c7966..0000000000 --- a/ggml/src/ggml-openvino/openvino/op/cpy.cpp +++ /dev/null @@ -1,73 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "../node_context.hpp" -#include 
"../op_table.hpp" -#include "../utils.hpp" - -namespace ov { -namespace frontend { -namespace ggml { -namespace op { - -OutputVector translate_cpy(const NodeContext& context) { - num_inputs_check(context, 2, 2); - - int op_case = context.get_op_case(); - FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2, "Unsupported CPY case"); - - auto src0 = context.get_input(0); - auto src1 = context.get_input(1); - - src0 = std::make_shared(src0, context.get_input_type(1)); - ov::Output res; - - if (context.is_static() && context.is_first_token()) { - res = src0; - return rename_outputs_with_suffix({res}, context.get_name()); - } - - if (op_case == 1) { - // Write K to cache_k - int64_t head_size = context.get_head_size(); - int64_t num_heads_kv = context.get_num_heads_kv(); - auto src0_reshape_shape = - ov::op::v0::Constant::create(ov::element::i64, {3}, std::vector{-1, num_heads_kv, head_size}); - src0 = std::make_shared(src0, src0_reshape_shape, false); - auto indices = context.get_input("update_indices_k"); - auto updated = std::make_shared(src1, indices, src0); - res = std::make_shared(updated, std::make_shared(src1), false); - } else { - // Write V to cache_v - auto flattend_src0 = - std::make_shared(src0, - ov::op::v0::Constant::create(element::i64, Shape{1}, {-1}), - false); - auto src0_shape = context.get_input_shape(0).to_shape(); - int64_t total_head_size = src0_shape[1]; - auto reshaped_src1 = std::make_shared( - src1, - ov::op::v0::Constant::create(ov::element::i64, {2}, std::vector{total_head_size, -1}), - false); - auto indices = context.get_input("update_indices_v"); - auto updated = std::make_shared(reshaped_src1, indices, flattend_src0); - res = std::make_shared(updated, std::make_shared(src1), false); - } - - return rename_outputs_with_suffix({res}, context.get_name()); -} - -} // namespace op -} // namespace ggml -} // namespace frontend -} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op_table.cpp 
b/ggml/src/ggml-openvino/openvino/op_table.cpp index 744f355a54..ce4b01c3b5 100644 --- a/ggml/src/ggml-openvino/openvino/op_table.cpp +++ b/ggml/src/ggml-openvino/openvino/op_table.cpp @@ -19,7 +19,6 @@ std::unordered_map get_supported_ops() { {"GGML_OP_ADD", op::translate_1to1_match_2_inputs }, {"GGML_OP_ADD1", op::translate_1to1_match_2_inputs }, {"GGML_OP_CONT", op::translate_cont }, - {"GGML_OP_CPY", op::translate_cpy }, {"GGML_OP_DIV", op::translate_1to1_match_2_inputs }, {"GGML_OP_GET_ROWS", op::translate_get_rows }, {"GGML_OP_MUL", op::translate_1to1_match_2_inputs}, diff --git a/ggml/src/ggml-openvino/openvino/op_table.hpp b/ggml/src/ggml-openvino/openvino/op_table.hpp index 631812aaa3..332930c3ac 100644 --- a/ggml/src/ggml-openvino/openvino/op_table.hpp +++ b/ggml/src/ggml-openvino/openvino/op_table.hpp @@ -12,7 +12,6 @@ namespace op { GGML_OP_CONVERTER(translate_add); GGML_OP_CONVERTER(translate_cont); -GGML_OP_CONVERTER(translate_cpy); GGML_OP_CONVERTER(translate_get_rows); GGML_OP_CONVERTER(translate_mul); GGML_OP_CONVERTER(translate_mulmat); diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp index daef12fb90..a09247347f 100644 --- a/ggml/src/ggml-openvino/openvino/translate_session.cpp +++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp @@ -76,65 +76,6 @@ void add_token_len(TensorMap& tensor_map) { tensor_map.insert({"token_len", token_len->output(0)}); } -void add_kv_update_indices(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) { - // cache_k layout: [S, N, H] (seq, num_heads, head_size) - // cache_v layout: [N, H, S] (num_heads, head_size, seq) - // When writing to cache_v, cache should be reshaped to [N*H, S] and v-curr should be flattened - auto past_token_len = tensor_map.at("past_token_len").get_node_shared_ptr(); - auto token_len = tensor_map.at("token_len").get_node_shared_ptr(); - - Output update_indices_k; - Output update_indices_v; - - auto zero = 
ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); - auto zero_scalar = ov::op::v0::Constant::create(ov::element::i64, {}, {0}); - auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); - auto one_scalar = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1}); - auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2}); - - auto past_token_len_scalar = std::make_shared(past_token_len, zero); - auto token_len_scalar = std::make_shared(token_len, zero); - auto total_token_len_scalar = std::make_shared(past_token_len_scalar, token_len_scalar); - - Output update_indices = std::make_shared( - past_token_len_scalar, total_token_len_scalar, one_scalar, ov::element::i64); - if (ggml_model_decoder.is_static()) { - update_indices = past_token_len; - } - - update_indices_k = std::make_shared(update_indices, one); - update_indices_k.get_node_shared_ptr()->set_friendly_name("update_indices_k"); - tensor_map.insert({"update_indices_k", update_indices_k}); - - auto total_head_size = ggml_model_decoder.get_num_heads_kv() * ggml_model_decoder.get_head_size(); - auto total_head_size_node = ov::op::v0::Constant::create(ov::element::i64, {1}, {total_head_size}); - auto total_head_size_scalar = std::make_shared(total_head_size_node, zero); - - // 1D tensor of shape [total_head_size], values starting from 0 - auto range_row = - std::make_shared(zero_scalar, total_head_size_scalar, one_scalar, ov::element::i64); - auto range_row_reshaped = - std::make_shared(range_row, ov::op::v0::Constant::create(ov::element::i64, {2}, {1, 2})); - auto row_indices = std::make_shared( - range_row_reshaped, - std::make_shared(ov::OutputVector{total_head_size_node, token_len, one}, 0)); - - // 1D tensor of shape [token_len], values starting from past_token_len - auto range_col = update_indices; - auto range_col_reshaped = - std::make_shared(range_col, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 2})); - auto col_indices = std::make_shared( - 
range_col_reshaped, - std::make_shared(ov::OutputVector{total_head_size_node, token_len, one}, 0)); - - // Stack row_indices and col_indices along last axis: [total_head_size, token_len, 2] - update_indices_v = std::make_shared(OutputVector{row_indices, col_indices}, 2); - update_indices_v = std::make_shared( - update_indices_v, ov::op::v0::Constant::create(ov::element::i64, {2}, std::vector{-1, 2}), false); - update_indices_v.get_node_shared_ptr()->set_friendly_name("update_indices_v"); - tensor_map.insert({"update_indices_v", update_indices_v}); -} - void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) { int32_t* rope_params = ggml_model_decoder.get_rope_params(); auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr(); @@ -156,7 +97,6 @@ void add_rope_sin_cos(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) { // Create common patterns void preprocess(TensorMap& tensor_map, GgmlDecoder& ggml_model_decoder) { add_token_len(tensor_map); - add_kv_update_indices(tensor_map, ggml_model_decoder); add_rope_sin_cos(tensor_map, ggml_model_decoder); }