diff --git a/gemma.cc b/gemma.cc index 7c9d187..f2a2275 100644 --- a/gemma.cc +++ b/gemma.cc @@ -25,8 +25,6 @@ #include "compression/compress-inl.h" // copybara:import_next_line:gemma_cpp #include "ops.h" -// copybara:import_next_line:gemma_cpp -#include "util/args.h" // Path #include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/highway.h" #include "hwy/profiler.h" @@ -815,9 +813,8 @@ void GemmaImpl::Generate( } Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - const Path& weights_path, Model model_type, ModelTraining training, - hwy::ThreadPool& pool) - : model_training(training) { + const Path& weights_path, Model model_type, + hwy::ThreadPool& pool) { std::unique_ptr tokenizer; { PROFILER_ZONE("Startup.tokenizer"); @@ -842,6 +839,11 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, } } +Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + Model model_type, hwy::ThreadPool& pool) + : Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type, + pool) {} + Gemma::~Gemma() = default; // after GemmaInterface is defined const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const { diff --git a/gemma.h b/gemma.h index d52356e..cdd4873 100644 --- a/gemma.h +++ b/gemma.h @@ -16,9 +16,12 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_H_ +#include +#include #include #include #include +#include #include // copybara:import_next_line:gemma_cpp @@ -28,7 +31,7 @@ #include "configs.h" // kSeqLen // copybara:end // copybara:import_next_line:gemma_cpp -#include "util/args.h" // Path +#include "util/args.h" // ArgsBase // copybara:end #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t @@ -72,8 +75,9 @@ struct GemmaInterface; struct Gemma { Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - const Path& weights_path, Model model_type, ModelTraining training, - hwy::ThreadPool& pool); + const Path& weights_path, Model model_type, hwy::ThreadPool& pool); + Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + Model model_type, hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr impl_; diff --git a/run.cc b/run.cc index ce9de93..b08e4ca 100644 --- a/run.cc +++ b/run.cc @@ -234,8 +234,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); } - gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights, - loader.ModelType(), loader.ModelTraining(), pool); + gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, + loader.ModelType(), pool); auto kv_cache = CreateKVCache(loader.ModelType());