Refactor: remove past_token_len from extra_inputs
This commit is contained in:
parent
acf358d1ce
commit
0fa7a5efef
|
|
@ -249,26 +249,16 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor* src) co
|
||||||
}
|
}
|
||||||
|
|
||||||
void GgmlOvDecoder::add_extra_inputs() {
|
void GgmlOvDecoder::add_extra_inputs() {
|
||||||
int64_t past_token_len = -1;
|
|
||||||
// attention_size not used for NPU
|
// attention_size not used for NPU
|
||||||
int64_t attention_size = -1;
|
int64_t attention_size = -1;
|
||||||
|
|
||||||
|
int64_t past_token_len = -1;
|
||||||
for (const auto& node : m_nodes) {
|
for (const auto& node : m_nodes) {
|
||||||
if (node->op == GGML_OP_CPY && ggml_is_contiguous(node)) {
|
if (node->op == GGML_OP_CPY && ggml_is_contiguous(node)) {
|
||||||
assert(std::string(node->view_src->name).find("cache_k") == 0);
|
assert(std::string(node->view_src->name).find("cache_k") == 0);
|
||||||
int64_t head_size = node->src[0]->ne[0];
|
int64_t head_size = node->src[0]->ne[0];
|
||||||
int64_t num_heads = node->src[0]->ne[1];
|
int64_t num_heads = node->src[0]->ne[1];
|
||||||
past_token_len = (int64_t)(node->src[1]->op_params[0] / node->src[1]->nb[0] / head_size / num_heads);
|
past_token_len = (int64_t) (node->src[1]->op_params[0] / node->src[1]->nb[0] / head_size / num_heads);
|
||||||
|
|
||||||
std::string name = "past_token_len";
|
|
||||||
auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{1});
|
|
||||||
param_node->set_friendly_name(name);
|
|
||||||
param_node->output(0).get_tensor().set_names({name});
|
|
||||||
m_model_extra_inputs[name] = param_node;
|
|
||||||
|
|
||||||
auto tensor = std::make_shared<ov::Tensor>(ov::element::i64, ov::Shape{1});
|
|
||||||
*tensor->data<int64_t>() = past_token_len;
|
|
||||||
m_model_extra_input_values[name] = tensor;
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -28,12 +28,6 @@ OutputVector translate_cpy(const NodeContext& context) {
|
||||||
|
|
||||||
auto src0 = context.get_input(0);
|
auto src0 = context.get_input(0);
|
||||||
auto src1 = context.get_input(1);
|
auto src1 = context.get_input(1);
|
||||||
auto token_len = context.get_input("token_len");
|
|
||||||
auto past_token_len = context.get_input("past_token_len");
|
|
||||||
|
|
||||||
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
|
|
||||||
auto token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(token_len, zero);
|
|
||||||
auto past_token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(past_token_len, zero);
|
|
||||||
|
|
||||||
src0 = std::make_shared<ov::op::v0::Convert>(src0, context.get_input_type(1));
|
src0 = std::make_shared<ov::op::v0::Convert>(src0, context.get_input_type(1));
|
||||||
ov::Output<Node> res;
|
ov::Output<Node> res;
|
||||||
|
|
@ -43,12 +37,6 @@ OutputVector translate_cpy(const NodeContext& context) {
|
||||||
return rename_outputs_with_suffix({res}, context.get_name());
|
return rename_outputs_with_suffix({res}, context.get_name());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto src0_shape = context.get_input_shape(0).to_shape();
|
|
||||||
auto output_shape = context.get_output_shape(0).to_shape();
|
|
||||||
|
|
||||||
std::vector<size_t> input0_strides = context.get_input_stride(0);
|
|
||||||
std::vector<size_t> output_strides = context.get_output_stride(0);
|
|
||||||
|
|
||||||
if (op_case == 1) {
|
if (op_case == 1) {
|
||||||
// Write K to cache_k
|
// Write K to cache_k
|
||||||
auto indices = context.get_input("update_indices_k");
|
auto indices = context.get_input("update_indices_k");
|
||||||
|
|
@ -60,6 +48,7 @@ OutputVector translate_cpy(const NodeContext& context) {
|
||||||
std::make_shared<ov::op::v1::Reshape>(src0,
|
std::make_shared<ov::op::v1::Reshape>(src0,
|
||||||
ov::op::v0::Constant::create(element::i64, Shape{1}, {-1}),
|
ov::op::v0::Constant::create(element::i64, Shape{1}, {-1}),
|
||||||
false);
|
false);
|
||||||
|
auto src0_shape = context.get_input_shape(0).to_shape();
|
||||||
int64_t total_head_size = src0_shape[1];
|
int64_t total_head_size = src0_shape[1];
|
||||||
auto reshaped_src1 = std::make_shared<ov::op::v1::Reshape>(
|
auto reshaped_src1 = std::make_shared<ov::op::v1::Reshape>(
|
||||||
src1,
|
src1,
|
||||||
|
|
|
||||||
|
|
@ -72,7 +72,6 @@ void add_kv_update_indices(TensorMap& tensor_map, GgmlDecoder& ggml_model_decode
|
||||||
// cache_v layout: [N, H, S] (num_heads, head_size, seq)
|
// cache_v layout: [N, H, S] (num_heads, head_size, seq)
|
||||||
// When writing to cache_v, cache should be reshaped to [N*H, S] and v-curr should be flattened
|
// When writing to cache_v, cache should be reshaped to [N*H, S] and v-curr should be flattened
|
||||||
auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr();
|
auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr();
|
||||||
auto past_token_len = tensor_map.at("past_token_len").get_node_shared_ptr();
|
|
||||||
auto token_len = tensor_map.at("token_len").get_node_shared_ptr();
|
auto token_len = tensor_map.at("token_len").get_node_shared_ptr();
|
||||||
|
|
||||||
std::shared_ptr<ov::Node> update_indices_k;
|
std::shared_ptr<ov::Node> update_indices_k;
|
||||||
|
|
@ -84,12 +83,8 @@ void add_kv_update_indices(TensorMap& tensor_map, GgmlDecoder& ggml_model_decode
|
||||||
auto one_scalar = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1});
|
auto one_scalar = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1});
|
||||||
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
|
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
|
||||||
|
|
||||||
if (ggml_model_decoder.is_static()) {
|
update_indices_k =
|
||||||
update_indices_k = past_token_len;
|
std::make_shared<ov::op::v0::Squeeze>(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
|
||||||
} else {
|
|
||||||
update_indices_k =
|
|
||||||
std::make_shared<ov::op::v0::Squeeze>(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
|
|
||||||
}
|
|
||||||
update_indices_k = std::make_shared<ov::op::v0::Unsqueeze>(update_indices_k, one);
|
update_indices_k = std::make_shared<ov::op::v0::Unsqueeze>(update_indices_k, one);
|
||||||
update_indices_k->set_friendly_name("update_indices_k");
|
update_indices_k->set_friendly_name("update_indices_k");
|
||||||
tensor_map.insert({"update_indices_k", update_indices_k->output(0)});
|
tensor_map.insert({"update_indices_k", update_indices_k->output(0)});
|
||||||
|
|
@ -108,14 +103,8 @@ void add_kv_update_indices(TensorMap& tensor_map, GgmlDecoder& ggml_model_decode
|
||||||
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{total_head_size_node, token_len, one}, 0));
|
std::make_shared<ov::op::v0::Concat>(ov::OutputVector{total_head_size_node, token_len, one}, 0));
|
||||||
|
|
||||||
// 1D tensor of shape [token_len], values starting from past_token_len
|
// 1D tensor of shape [token_len], values starting from past_token_len
|
||||||
std::shared_ptr<ov::Node> range_col;
|
auto range_col =
|
||||||
if (ggml_model_decoder.is_static()) {
|
std::make_shared<ov::op::v0::Squeeze>(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
|
||||||
// aka inp_pos
|
|
||||||
range_col = past_token_len;
|
|
||||||
} else {
|
|
||||||
range_col =
|
|
||||||
std::make_shared<ov::op::v0::Squeeze>(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1}));
|
|
||||||
}
|
|
||||||
auto range_col_reshaped =
|
auto range_col_reshaped =
|
||||||
std::make_shared<ov::op::v0::Unsqueeze>(range_col, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 2}));
|
std::make_shared<ov::op::v0::Unsqueeze>(range_col, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 2}));
|
||||||
auto col_indices = std::make_shared<ov::op::v3::Broadcast>(
|
auto col_indices = std::make_shared<ov::op::v3::Broadcast>(
|
||||||
|
|
@ -233,10 +222,9 @@ void TranslateSession::apply_transformations(const std::shared_ptr<Model>& model
|
||||||
const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
|
const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
|
||||||
const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);
|
const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names);
|
||||||
manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
|
manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs);
|
||||||
|
|
||||||
manager.register_pass<pass::FuseToSDPA>();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
manager.register_pass<pass::FuseToSDPA>();
|
||||||
manager.run_passes(model);
|
manager.run_passes(model);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue