diff --git a/gemma.cc b/gemma.cc index 35a4a47..7c9d187 100644 --- a/gemma.cc +++ b/gemma.cc @@ -25,6 +25,8 @@ #include "compression/compress-inl.h" // copybara:import_next_line:gemma_cpp #include "ops.h" +// copybara:import_next_line:gemma_cpp +#include "util/args.h" // Path #include "hwy/contrib/matvec/matvec-inl.h" #include "hwy/highway.h" #include "hwy/profiler.h" @@ -813,8 +815,9 @@ void GemmaImpl::Generate( } Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - const Path& weights_path, Model model_type, - hwy::ThreadPool& pool) { + const Path& weights_path, Model model_type, ModelTraining training, + hwy::ThreadPool& pool) + : model_training(training) { std::unique_ptr tokenizer; { PROFILER_ZONE("Startup.tokenizer"); diff --git a/gemma.h b/gemma.h index f5e88fa..d52356e 100644 --- a/gemma.h +++ b/gemma.h @@ -16,12 +16,9 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_H_ -#include -#include #include #include #include -#include #include // copybara:import_next_line:gemma_cpp @@ -31,7 +28,7 @@ #include "configs.h" // kSeqLen // copybara:end // copybara:import_next_line:gemma_cpp -#include "util/args.h" // ArgsBase +#include "util/args.h" // Path // copybara:end #include "hwy/aligned_allocator.h" #include "hwy/base.h" // hwy::bfloat16_t @@ -75,7 +72,8 @@ struct GemmaInterface; struct Gemma { Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, - const Path& weights_path, Model model_type, hwy::ThreadPool& pool); + const Path& weights_path, Model model_type, ModelTraining training, + hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr impl_; diff --git a/run.cc b/run.cc index fcf974b..ce9de93 100644 --- a/run.cc +++ b/run.cc @@ -235,7 +235,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { } gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights, - loader.ModelType(), pool); + loader.ModelType(), loader.ModelTraining(), pool); auto kv_cache = CreateKVCache(loader.ModelType());