Connect "--weights" parameter to Gemma

PiperOrigin-RevId: 617323257
This commit is contained in:
Eric Ye 2024-03-20 00:07:47 +01:00 committed by Jan Wassenberg
parent fdc3812446
commit 6865819bb7
3 changed files with 1 addition and 8 deletions

View File

@@ -839,11 +839,6 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
}
}
Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
Model model_type, hwy::ThreadPool& pool)
: Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type,
pool) {}
Gemma::~Gemma() = default; // after GemmaInterface is defined
const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {

View File

@@ -76,8 +76,6 @@ struct GemmaInterface;
struct Gemma {
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
const Path& weights_path, Model model_type, hwy::ThreadPool& pool);
Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
Model model_type, hwy::ThreadPool& pool);
~Gemma(); // must be defined after GemmaInterface's dtor is defined.
const sentencepiece::SentencePieceProcessor* Tokenizer() const;
std::unique_ptr<GemmaInterface> impl_;

2
run.cc
View File

@@ -234,7 +234,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
[](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
}
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, loader.weights,
loader.ModelType(), pool);
auto kv_cache = CreateKVCache(loader.ModelType());