diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt
index 088af84..43159d7 100644
--- a/examples/hello_world/CMakeLists.txt
+++ b/examples/hello_world/CMakeLists.txt
@@ -22,7 +22,18 @@ FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.gi
 FetchContent_MakeAvailable(highway)
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
 FetchContent_MakeAvailable(sentencepiece)
-FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG e781007836ec034236e90cc4d313d0a8c481bce6)
+
+
+# Allow for both local and remote builds
+option(BUILD_MODE "'local' or 'remote' git fetch for builds")
+if (NOT BUILD_MODE)
+  set(BUILD_MODE "remote")
+endif()
+if (BUILD_MODE STREQUAL "local")
+  FetchContent_Declare(gemma SOURCE_DIR ../../..)
+else()
+  FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 42e53e2da89f80dc46399c7037fbbfb15cdc3de3)
+endif()
 FetchContent_MakeAvailable(gemma)
 
 if(NOT CMAKE_BUILD_TYPE)
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index fd6c762..8b5de24 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -22,15 +22,17 @@ std::vector<int> tokenize(
 int main(int argc, char** argv) {
   gcpp::LoaderArgs loader(argc, argv);
 
-  // A rough heuristic number of threads to use
+  // Rough heuristic for the number of threads to use
   size_t num_threads = static_cast<size_t>(std::clamp(
       static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
   hwy::ThreadPool pool(num_threads);
 
-  // Instantiate model
-  gcpp::Gemma model(loader, pool);
+  // Instantiate model and KV Cache
+  gcpp::Gemma model(loader.tokenizer, loader.cache, loader.ModelType(), pool);
+  auto kv_cache = CreateKVCache(loader.ModelType());
+  size_t pos = 0;  // KV Cache position
 
-  // Setup random number generator
+  // Initialize random number generator
   std::mt19937 gen;
   std::random_device rd;
   gen.seed(rd());
@@ -39,7 +41,6 @@ int main(int argc, char** argv) {
   std::vector<int> tokens =
       tokenize("Write a greeting to the world.", model.Tokenizer());
   size_t ntokens = tokens.size();
-  size_t pos = 0;
 
   // Callback
   auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()](
@@ -60,6 +61,7 @@ int main(int argc, char** argv) {
                  .max_generated_tokens = 1024,
                  .temperature = 1.0,
                  .verbosity = 0},
-                tokens, /*KV cache position = */ 0, pool, stream_token, gen);
+                tokens, /*KV cache position = */ 0, kv_cache, pool,
+                stream_token, gen);
   std::cout << std::endl;
 }
diff --git a/gemma.cc b/gemma.cc
index 9c2df6c..15b3c26 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -285,7 +285,6 @@ struct GemmaImpl : public GemmaInterface {
                 int verbosity) override;
 
   std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
-  hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
   hwy::AlignedUniquePtr<Activations<TConfig, kPrefillBatchSize>> prefill;
   hwy::AlignedUniquePtr<Activations<TConfig, 1>> state;
 
@@ -803,15 +802,15 @@ void GemmaImpl<TConfig>::Generate(
       kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
 
-Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
-  const Model model_type = args.ModelType();
-  model_training = args.ModelTraining();
+Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
+             const Path& weights_path, Model model_type,
+             hwy::ThreadPool& pool) {
   PROFILER_ZONE("Startup.tokenizer");
   std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer =
       std::make_unique<sentencepiece::SentencePieceProcessor>();
-  HWY_ASSERT(tokenizer->Load(args.tokenizer.path).ok());
+  HWY_ASSERT(tokenizer->Load(tokenizer_path.path).ok());
   auto compressed_weights = HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(
-      args.ModelType(), args.model, args.cache, pool);
+      model_type, weights_path, compressed_weights_path, pool);
   switch (model_type) {
     case Model::GEMMA_2B:
       impl_.reset(
@@ -825,6 +824,12 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
   }
 }
+
+Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
+             Model model_type, hwy::ThreadPool& pool)
+    : Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type,
+            pool) {}
+
 Gemma::~Gemma() = default;  // after GemmaInterface is defined
 
 const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
diff --git a/gemma.h b/gemma.h
index 3bb4c28..7c08412 100644
--- a/gemma.h
+++ b/gemma.h
@@ -192,7 +192,10 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
 struct GemmaInterface;
 
 struct Gemma {
-  Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
+  Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
+        const Path& weights_path, Model model_type, hwy::ThreadPool& pool);
+  Gemma(const Path& tokenizer_path, const Path& compressed_weights_path,
+        Model model_type, hwy::ThreadPool& pool);
   ~Gemma();  // must be defined after GemmaInterface's dtor is defined.
   const sentencepiece::SentencePieceProcessor* Tokenizer() const;
   std::unique_ptr<GemmaInterface> impl_;
diff --git a/run.cc b/run.cc
index 64b6399..cdeb95e 100644
--- a/run.cc
+++ b/run.cc
@@ -235,7 +235,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
         [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); });
   }
 
-  gcpp::Gemma model(loader, pool);
+  gcpp::Gemma model(loader.tokenizer, loader.cache, loader.ModelType(), pool);
 
   auto kv_cache = CreateKVCache(loader.ModelType());
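
For reference, a minimal sketch (not part of the diff) of the resulting call
pattern for downstream users: the model is built from explicit tokenizer and
compressed-weights paths via the new two-path constructor, and the KV cache is
created and owned by the caller. The file names are placeholders, and the
"gemma.h" include and the gcpp:: qualification of CreateKVCache are
assumptions, not shown in the diff itself.

#include "gemma.h"  // gcpp::Gemma, gcpp::Path, CreateKVCache (assumed)
#include "hwy/contrib/thread_pool/thread_pool.h"

void Example(hwy::ThreadPool& pool) {
  gcpp::Path tokenizer_path;
  tokenizer_path.path = "tokenizer.spm";           // placeholder file name
  gcpp::Path compressed_weights_path;
  compressed_weights_path.path = "2b-it-sfp.sbs";  // placeholder file name

  // Two-path overload: per the gemma.cc change above, this delegates to the
  // three-path constructor with an empty uncompressed-weights Path.
  gcpp::Gemma model(tokenizer_path, compressed_weights_path,
                    gcpp::Model::GEMMA_2B, pool);

  // The KV cache no longer lives inside the model; callers create it and
  // pass it (plus a position) into each generation call, as hello_world does.
  auto kv_cache = gcpp::CreateKVCache(gcpp::Model::GEMMA_2B);
  (void)kv_cache;  // would be passed to generation alongside prompt tokens
}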