Add MarkCompressedFloatConstants (mark decompression) pass
This commit is contained in:
parent
93b2d09a2d
commit
1a19566b23
|
|
@ -0,0 +1,29 @@
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "mark_decompression_convert_constant_folding.hpp"
|
||||||
|
#include "openvino/pass/matcher_pass.hpp"
|
||||||
|
#include "openvino/core/visibility.hpp"
|
||||||
|
|
||||||
|
// Define TRANSFORMATIONS_API, the decoration applied to the pass class declared
// below. It expands to nothing when OPENVINO_STATIC_LIBRARY is defined, to
// OPENVINO_CORE_EXPORTS when IMPLEMENT_OPENVINO_API is defined, and to
// OPENVINO_CORE_IMPORTS otherwise.
#ifdef OPENVINO_STATIC_LIBRARY
|
||||||
|
# define TRANSFORMATIONS_API
|
||||||
|
#else
|
||||||
|
# ifdef IMPLEMENT_OPENVINO_API
|
||||||
|
# define TRANSFORMATIONS_API OPENVINO_CORE_EXPORTS
|
||||||
|
# else
|
||||||
|
# define TRANSFORMATIONS_API OPENVINO_CORE_IMPORTS
|
||||||
|
# endif // IMPLEMENT_OPENVINO_API
|
||||||
|
#endif // OPENVINO_STATIC_LIBRARY
|
||||||
|
|
||||||
|
// Forward-declare the pass inside ov::pass with the TRANSFORMATIONS_API
// decoration; the class itself is defined after both namespaces are closed,
// using its fully qualified name.
namespace ov {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
class TRANSFORMATIONS_API MarkCompressedFloatConstants;
|
||||||
|
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ov
|
||||||
|
|
||||||
|
/// Matcher pass registered under the RTTI name "MarkCompressedFloatConstants".
///
/// This header only declares the default constructor; the pass behavior lives
/// in the constructor's definition, which is not part of this header.
/// NOTE(review): presumably marks compressed float constants so that a
/// following ConstantFolding pass does not fold their decompression Convert;
/// confirm against the implementation.
class ov::pass::MarkCompressedFloatConstants : public MatcherPass {
|
||||||
|
public:
|
||||||
|
OPENVINO_MATCHER_PASS_RTTI("MarkCompressedFloatConstants");
|
||||||
|
MarkCompressedFloatConstants();
|
||||||
|
};
|
||||||
|
|
@ -28,6 +28,7 @@
|
||||||
#include "ggml-openvino/openvino/utils.hpp"
|
#include "ggml-openvino/openvino/utils.hpp"
|
||||||
#include "input_model.hpp"
|
#include "input_model.hpp"
|
||||||
#include "pass/fuse_to_sdpa.hpp"
|
#include "pass/fuse_to_sdpa.hpp"
|
||||||
|
#include "pass/mark_decompression_convert_constant_folding.hpp"
|
||||||
|
|
||||||
namespace ov {
|
namespace ov {
|
||||||
namespace frontend {
|
namespace frontend {
|
||||||
|
|
@ -259,6 +260,8 @@ std::shared_ptr<Model> TranslateSession::apply_transformations(std::shared_ptr<M
|
||||||
{
|
{
|
||||||
ov::pass::Manager manager;
|
ov::pass::Manager manager;
|
||||||
manager.set_per_pass_validation(true);
|
manager.set_per_pass_validation(true);
|
||||||
|
manager.register_pass<ov::pass::MarkCompressedFloatConstants>();
|
||||||
|
manager.register_pass<ov::pass::ConstantFolding>();
|
||||||
|
|
||||||
if (!ggml_model_decoder->is_static()) {
|
if (!ggml_model_decoder->is_static()) {
|
||||||
const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
|
const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names();
|
||||||
|
|
@ -267,7 +270,7 @@ std::shared_ptr<Model> TranslateSession::apply_transformations(std::shared_ptr<M
|
||||||
}
|
}
|
||||||
|
|
||||||
// SDPA is even worse on performance
|
// SDPA is even worse on performance
// NOTE(review): FuseToSDPA is nevertheless registered below; confirm enabling it is intended.
|
||||||
// manager.register_pass<pass::FuseToSDPA>();
|
manager.register_pass<pass::FuseToSDPA>();
|
||||||
manager.run_passes(model);
|
manager.run_passes(model);
|
||||||
}
|
}
|
||||||
auto preprocessor = ov::preprocess::PrePostProcessor(model);
|
auto preprocessor = ov::preprocess::PrePostProcessor(model);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue