NPU support version 2: prefill + kvcache

Authored by Yu, Zijun on 2025-06-03 14:22:51 +08:00; committed by Mustafa Cavus
parent 34531abce4
commit d9ca8f5dbe
5 changed files with 52 additions and 28 deletions

View File

@@ -222,11 +222,11 @@ void GgmlOvDecoder::add_extra_inputs() {
         past_token_len = (int64_t)(node->src[1]->op_params[0] / node->src[1]->nb[0] / head_size / num_heads);
         std::string name = "past_token_len";
-        auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{});
+        auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{1});
         param_node->set_friendly_name(name);
         m_model_extra_inputs[name] = param_node;
-        auto tensor = std::make_shared<ov::Tensor>(ov::element::i64, ov::Shape{});
+        auto tensor = std::make_shared<ov::Tensor>(ov::element::i64, ov::Shape{1});
         *tensor->data<int64_t>() = past_token_len;
         m_model_extra_input_values[name] = tensor;
         break;
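
Note: the change from ov::Shape{} to ov::Shape{1} swaps a rank-0 scalar for a one-element rank-1 tensor, presumably so the static NPU pipeline gets a regular indexable input instead of a scalar. A minimal sketch of the two forms (illustrative only, not part of the commit):

    #include <openvino/openvino.hpp>

    ov::Tensor scalar_t(ov::element::i64, ov::Shape{});  // rank 0: one element, no axes
    ov::Tensor vec_t(ov::element::i64, ov::Shape{1});    // rank 1: one element, one axis
    vec_t.data<int64_t>()[0] = 42;  // same write pattern as *tensor->data<int64_t>() above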

View File

@@ -34,7 +34,7 @@ OutputVector translate_cpy(const NodeContext& context) {
    auto src0 = context.get_input(0);
    auto src1 = context.get_input(1);
-   auto past_token_len_scalar = context.get_input("past_token_len");
+   auto past_token_len = context.get_input("past_token_len");
    src0 = std::make_shared<ov::op::v0::Convert>(src0, context.get_input_type(1));
    ov::Output<Node> res;
@@ -68,18 +68,16 @@ OutputVector translate_cpy(const NodeContext& context) {
        std::shared_ptr<ov::Node> indices;
        if (context.is_static()) {
-           indices = past_token_len_scalar.get_node_shared_ptr();
-           indices = std::make_shared<ov::op::v0::Unsqueeze>(
-               indices,
-               ov::op::v0::Constant::create(ov::element::i64, {2}, std::vector<int64_t>{0, 1}));
+           indices = past_token_len.get_node_shared_ptr();
        } else {
+           auto past_token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(past_token_len, zero);
            auto total_token_len_scalar = std::make_shared<ov::op::v1::Add>(past_token_len_scalar, token_len_scalar);
            indices = std::make_shared<ov::op::v4::Range>(past_token_len_scalar,
                                                          total_token_len_scalar,
                                                          one_scalar,
                                                          ov::element::i64);
-           indices = std::make_shared<ov::op::v0::Unsqueeze>(indices, one);
        }
+       indices = std::make_shared<ov::op::v0::Unsqueeze>(indices, one);
        res = std::make_shared<ov::op::v3::ScatterNDUpdate>(reshaped_src1, indices, src0);
    } else {
} else {
@@ -108,11 +106,9 @@ OutputVector translate_cpy(const NodeContext& context) {
        // 1D tensor of shape [token_len], values starting from past_token_len
        std::shared_ptr<ov::Node> range_col;
        if (context.is_static()) {
-           range_col = past_token_len_scalar.get_node_shared_ptr();
-           range_col = std::make_shared<ov::op::v0::Unsqueeze>(
-               range_col,
-               ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{0}));
+           range_col = past_token_len.get_node_shared_ptr();
        } else {
+           auto past_token_len_scalar = std::make_shared<ov::op::v0::Squeeze>(past_token_len, zero);
            auto total_token_len_scalar = std::make_shared<ov::op::v1::Add>(past_token_len_scalar, token_len_scalar);
            range_col = std::make_shared<ov::op::v4::Range>(past_token_len_scalar,
                                                            total_token_len_scalar,
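
Note: across the translate_cpy hunks, the renamed past_token_len input (now shaped [1]) is used directly as the index in the static branch, the dynamic branch re-derives the scalar with a Squeeze, and the trailing Unsqueeze moves out of the if/else so both branches yield indices of the same rank. A sketch of the dynamic-branch index construction; past_token_len and token_len_scalar are assumed to come from the surrounding code:

    // Assumed: past_token_len is an i64 input of shape [1]; token_len_scalar
    // is an i64 scalar holding the number of tokens in the current batch.
    auto zero = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
    auto one = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {1});
    auto one_scalar = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {1});

    auto past_scalar = std::make_shared<ov::op::v0::Squeeze>(past_token_len, zero);  // [1] -> scalar
    auto total_scalar = std::make_shared<ov::op::v1::Add>(past_scalar, token_len_scalar);
    auto positions = std::make_shared<ov::op::v4::Range>(past_scalar, total_scalar,
                                                         one_scalar, ov::element::i64);
    // [token_len] -> [token_len, 1]: one row index per new token, the layout
    // ScatterNDUpdate needs to write the new K/V rows into the cache.
    auto indices = std::make_shared<ov::op::v0::Unsqueeze>(positions, one);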

View File

@@ -1,3 +1,4 @@
+#include <climits>
 #include <cstdint>
 #include <memory>
 #include <openvino/core/node.hpp>
@@ -68,7 +69,7 @@ OutputVector translate_mulmat(const NodeContext& context) {
    std::vector<int64_t> src0_original_shape(src0_original_shape_.begin(), src0_original_shape_.end());
    if (context.is_static()) {
-       attention_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {src0_original_shape[token_dim]});
+       attention_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {INT_MAX});
    }
    src0_original_shape[token_dim] = -1;
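
Note: in static (NPU) mode attention_size is now pinned to INT_MAX instead of the current token count, so one compiled graph can serve any cache fill level. If, as the name suggests, the value bounds how much of the KV cache is read, INT_MAX works because OpenVINO's Slice clamps an out-of-range stop to the end of the axis. A hedged sketch of that pattern; kv_cache and the choice of token axis are assumptions, not shown in this diff:

    #include <climits>

    auto start = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
    auto stop = ov::op::v0::Constant::create(ov::element::i64, {1}, {INT_MAX});  // "to the end"
    auto step = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
    auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});  // assumed token axis
    auto attended = std::make_shared<ov::op::v8::Slice>(kv_cache, start, stop, step, axes);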

View File

@@ -1,6 +1,7 @@
 #include "utils.h"
 #include <algorithm>
+#include <cassert>
 #include <cmath>
 #include <cstddef>
 #include <cstdint>
@@ -70,15 +71,17 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
    ov::AnyMap config;
    if (device == "NPU") {
        config = {
-           {"NPU_COMPILATION_MODE_PARAMS", "compute-layers-with-higher-precision=ReduceMean"},
-           {"NPU_USE_NPUW", "YES"},
-           {"NPUW_DEVICES", "NPU"},
-           {"NPUW_FOLD", "YES"},
-           {"NPUW_DQ", "YES"},
-           {"NPUW_FUNCALL_ASYNC", "YES"},
-           {"NPUW_HOST_GATHER", "YES"},
-           {"NPUW_WEIGHTS_BANK", "shared"},
-           // {"NPU_COMPILER_TYPE", "MLIR"},
+           { "NPU_COMPILATION_MODE_PARAMS", "compute-layers-with-higher-precision=ReduceMean" },
+           { "NPU_USE_NPUW", "YES" },
+           { "NPUW_DEVICES", "NPU" },
+           { "NPUW_FOLD", "YES" },
+           { "NPUW_HOST_GATHER", "YES" },
+           { "NPUW_DQ", "YES" },
+           { "NPUW_FUNCALL_ASYNC", "YES" },
+           { "NPUW_WEIGHTS_BANK", "shared" },
+           // Option 'CACHE_DIR' is not supported with MLIR compiler type
+           // {"NPUW_CACHE_DIR", getenv("GGML_OPENVINO_CACHE_DIR") ? getenv("GGML_OPENVINO_CACHE_DIR") : ""},
+           { "NPU_COMPILER_TYPE", "MLIR" },
        };
    }
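
Note: the reordered block keeps the same NPUW settings but now enables the MLIR compiler type explicitly and records why the cache-dir option stays commented out (CACHE_DIR is not supported with that compiler type). The map is ordinary OpenVINO configuration; a minimal sketch of how such an ov::AnyMap is consumed, with model and device taken from the surrounding function:

    ov::Core core;
    // "NPU_USE_NPUW" routes compilation through NPUW; the NPUW_* keys control
    // partitioning, weight folding, and async function calls on the NPU.
    auto compiled = core.compile_model(model, device, config);
    auto infer_request = compiled.create_infer_request();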
@@ -102,15 +105,21 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
    int64_t conversion_end_time;
    int64_t compile_end_time;
+   bool is_first_token = is_prefill(cgraph);
    auto it = compiled_cache_prefill.find(cgraph);
-   bool is_first_token = it == compiled_cache_prefill.end();
-   if (!is_first_token) {
+   if (it != compiled_cache_prefill.end()) {
        ggml_decoder = get_ggml_decoder(cgraph, is_static, false);
        decoder_end_time = ggml_time_us();
        if (is_static) {
-           model = compiled_cache_kvcache[cgraph].first;
-           compiled_model = compiled_cache_kvcache[cgraph].second;
+           if (is_first_token) {
+               model = compiled_cache_prefill[cgraph].first;
+               compiled_model = compiled_cache_prefill[cgraph].second;
+           } else {
+               model = compiled_cache_kvcache[cgraph].first;
+               compiled_model = compiled_cache_kvcache[cgraph].second;
+           }
        } else {
            model = it->second.first;
            compiled_model = it->second.second;
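
Note: is_first_token now comes from the graph itself (via is_prefill below) rather than from whether the prefill cache already holds an entry, and on the static NPU path a cache hit selects the prefill or kvcache compiled model accordingly. The two caches are not declared in this diff; a sketch of their assumed shape, inferred from the .first/.second accesses:

    #include <map>
    #include <memory>
    #include <utility>

    using CacheEntry = std::pair<std::shared_ptr<ov::Model>, ov::CompiledModel>;
    // One cache per phase: on NPU the multi-token prefill graph and the
    // single-token kvcache graph compile to different static shapes.
    static std::map<struct ggml_cgraph *, CacheEntry> compiled_cache_prefill;
    static std::map<struct ggml_cgraph *, CacheEntry> compiled_cache_kvcache;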
@@ -235,8 +244,6 @@ enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_c
    }
    auto end_time = ggml_time_us();
-   is_first_token = false;
    if (getenv("GGML_OPENVINO_PROFILING")) {
        GGML_LOG_INFO("GGML OpenVINO Backend: \n");
        GGML_LOG_INFO(" - Graph decoder Time: %ld ms \n", (decoder_end_time - start_time) / 1000);
@@ -305,3 +312,20 @@ void set_zero_diagonal(std::vector<float>& matrix, size_t dim) {
        matrix[i * dim + i] = 0.0f;
    }
 }
+
+bool is_prefill(struct ggml_cgraph * cgraph) {
+    for (int i = 0; i < cgraph->n_nodes; ++i) {
+        auto * op = cgraph->nodes[i];
+        for (int j = 0; j < GGML_MAX_SRC; ++j) {
+            auto * src = op->src[j];
+            if (src == nullptr) {
+                break;
+            }
+            if (std::string(src->name) == "inp_tokens") {
+                return src->ne[0] != 1;
+            }
+        }
+    }
+    GGML_LOG_ERROR("is_prefill: inp_tokens not found in cgraph");
+    throw std::runtime_error("is_prefill: inp_tokens not found in cgraph");
+}
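
Note: is_prefill reads the phase off the graph itself: ne[0] of the tensor named "inp_tokens" is the number of token ids in the batch, so more than one token means prefill and exactly one means a decode (kvcache) step. A sketch of where that name originates; ctx and n_tokens are assumed (llama.cpp names its token-id input this way):

    // Illustrative only: the token-id input carries the name the helper searches for.
    ggml_tensor * inp = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
    ggml_set_name(inp, "inp_tokens");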

View File

@@ -2,6 +2,7 @@
 #include "ggml-backend-impl.h"
 #include "ggml-decoder.h"
+#include "ggml-impl.h"

 enum ggml_status openvino_frontend_compute(ggml_backend_t backend, struct ggml_cgraph* cgraph);
@@ -35,3 +36,5 @@ std::vector<T> pad_input(const ggml_tensor* tensor, size_t padded_rows, size_t p
 }

 void set_zero_diagonal(std::vector<float>& matrix, size_t dim);
+
+bool is_prefill(struct ggml_cgraph * cgraph);