From 39cd59caec8a66aa3e31205a6c0ca9d4daca82f5 Mon Sep 17 00:00:00 2001
From: austinvhuang
Date: Sun, 3 Mar 2024 10:33:29 -0500
Subject: [PATCH 01/20] [WIP] create skeleton for example frontend application

---
 examples/look/CMakeLists.txt   | 45 ++++++++++++++++++++++++++++++++++
 examples/look/build/.gitignore |  0
 examples/look/run.cc           | 28 +++++++++++++++++++++
 3 files changed, 73 insertions(+)
 create mode 100644 examples/look/CMakeLists.txt
 create mode 100644 examples/look/build/.gitignore
 create mode 100644 examples/look/run.cc

diff --git a/examples/look/CMakeLists.txt b/examples/look/CMakeLists.txt
new file mode 100644
index 0000000..8666ce7
--- /dev/null
+++ b/examples/look/CMakeLists.txt
@@ -0,0 +1,45 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+cmake_minimum_required(VERSION 3.11)
+
+project(look)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+include(FetchContent)
+
+FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
+FetchContent_MakeAvailable(highway)
+
+FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
+FetchContent_MakeAvailable(sentencepiece)
+
+FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 808dbdc42b216c3ac1f1c40dfa638bcff24bbd2b)
+FetchContent_MakeAvailable(gemma)
+
+if(NOT CMAKE_BUILD_TYPE)
+  set(CMAKE_BUILD_TYPE "Release")
+endif()
+
+add_executable(look run.cc)
+target_sources(look PRIVATE ${SOURCES})
+set_property(TARGET look PROPERTY CXX_STANDARD 17)
+target_link_libraries(look hwy hwy_contrib sentencepiece libgemma)
+target_include_directories(look PRIVATE ./)
+FetchContent_GetProperties(sentencepiece)
+target_include_directories(look PRIVATE ${sentencepiece_SOURCE_DIR})
+target_compile_definitions(look PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
+target_compile_options(look PRIVATE $<$<PLATFORM_ID:Windows>:-Wno-deprecated-declarations>)
diff --git a/examples/look/build/.gitignore b/examples/look/build/.gitignore
new file mode 100644
index 0000000..e69de29
diff --git a/examples/look/run.cc b/examples/look/run.cc
new file mode 100644
index 0000000..85b0da4
--- /dev/null
+++ b/examples/look/run.cc
@@ -0,0 +1,28 @@
+#include <iostream>
+
+// copybara:import_next_line:gemma_cpp
+#include "compression/compress.h"
+// copybara:end
+// copybara:import_next_line:gemma_cpp
+#include "gemma.h"  // Gemma
+// copybara:end
+// copybara:import_next_line:gemma_cpp
+#include "util/app.h"
+// copybara:end
+// copybara:import_next_line:gemma_cpp
+#include "util/args.h"  // HasHelp
+// copybara:end
+#include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/highway.h"
+#include "hwy/per_target.h"
+#include "hwy/profiler.h"
+#include "hwy/timer.h"
+
+int main(int argc, char** argv) {
+  gcpp::LoaderArgs loader(argc, argv);
+  gcpp::AppArgs app(argc, argv);
+  hwy::ThreadPool pool(app.num_threads);
+  gcpp::Gemma model(loader, pool);
+ std::cout << "Done" << std::endl; +} From c378ac2c565081cfc9bf0e57531d9531948a6e32 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Sun, 3 Mar 2024 11:36:48 -0500 Subject: [PATCH 02/20] [WIP] hello world example working. TODO: refactor interfaces to decouple arguments --- examples/{look => hello_world}/CMakeLists.txt | 18 ++--- .../{look => hello_world}/build/.gitignore | 0 examples/hello_world/run.cc | 74 +++++++++++++++++++ examples/look/run.cc | 28 ------- models/.gitignore | 0 5 files changed, 83 insertions(+), 37 deletions(-) rename examples/{look => hello_world}/CMakeLists.txt (70%) rename examples/{look => hello_world}/build/.gitignore (100%) create mode 100644 examples/hello_world/run.cc delete mode 100644 examples/look/run.cc create mode 100644 models/.gitignore diff --git a/examples/look/CMakeLists.txt b/examples/hello_world/CMakeLists.txt similarity index 70% rename from examples/look/CMakeLists.txt rename to examples/hello_world/CMakeLists.txt index 8666ce7..69e4e98 100644 --- a/examples/look/CMakeLists.txt +++ b/examples/hello_world/CMakeLists.txt @@ -14,7 +14,7 @@ cmake_minimum_required(VERSION 3.11) -project(look) +project(hello_world) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) @@ -34,12 +34,12 @@ if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE "Release") endif() -add_executable(look run.cc) -target_sources(look PRIVATE ${SOURCES}) -set_property(TARGET look PROPERTY CXX_STANDARD 17) -target_link_libraries(look hwy hwy_contrib sentencepiece libgemma) -target_include_directories(look PRIVATE ./) +add_executable(hello_world run.cc) +target_sources(hello_world PRIVATE ${SOURCES}) +set_property(TARGET hello_world PROPERTY CXX_STANDARD 17) +target_link_libraries(hello_world hwy hwy_contrib sentencepiece libgemma) +target_include_directories(hello_world PRIVATE ./) FetchContent_GetProperties(sentencepiece) -target_include_directories(look PRIVATE ${sentencepiece_SOURCE_DIR}) -target_compile_definitions(look PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) -target_compile_options(look PRIVATE $<$:-Wno-deprecated-declarations>) +target_include_directories(hello_world PRIVATE ${sentencepiece_SOURCE_DIR}) +target_compile_definitions(hello_world PRIVATE $<$:_CRT_SECURE_NO_WARNINGS NOMINMAX>) +target_compile_options(hello_world PRIVATE $<$:-Wno-deprecated-declarations>) diff --git a/examples/look/build/.gitignore b/examples/hello_world/build/.gitignore similarity index 100% rename from examples/look/build/.gitignore rename to examples/hello_world/build/.gitignore diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc new file mode 100644 index 0000000..a017e22 --- /dev/null +++ b/examples/hello_world/run.cc @@ -0,0 +1,74 @@ +#include + +// copybara:import_next_line:gemma_cpp +#include "compression/compress.h" +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "gemma.h" // Gemma +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/app.h" +// copybara:end +// copybara:import_next_line:gemma_cpp +#include "util/args.h" // HasHelp +// copybara:end +#include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" +#include "hwy/highway.h" +#include "hwy/per_target.h" +#include "hwy/profiler.h" +#include "hwy/timer.h" + +std::vector tokenize(std::string prompt_string, const sentencepiece::SentencePieceProcessor& tokenizer) { + prompt_string = "user\n" + prompt_string + + "\nmodel\n"; + std::vector tokens; + HWY_ASSERT(tokenizer.Encode(prompt_string, &tokens).ok()); + tokens.insert(tokens.begin(), 2); // BOS token + 
return tokens; +} + +int main(int argc, char** argv) { + gcpp::InferenceArgs inference(argc, argv); + gcpp::LoaderArgs loader(argc, argv); + gcpp::AppArgs app(argc, argv); + hwy::ThreadPool pool(app.num_threads); + hwy::ThreadPool inner_pool(0); + gcpp::Gemma model(loader, pool); + + std::vector tokens = tokenize("Hello, how are you?", model.Tokenizer()); + + std::mt19937 gen; + std::random_device rd; + gen.seed(rd()); + + size_t ntokens = tokens.size(); + + size_t pos = 0; + + auto stream_token = [&pos, &gen, &ntokens, tokenizer = &model.Tokenizer()](int token, float) { + ++pos; + if (pos < ntokens) { + // print feedback + } else if (token != gcpp::EOS_ID) { + std::string token_text; + HWY_ASSERT(tokenizer->Decode(std::vector{token}, &token_text).ok()); + if (pos == ntokens + 1) { + // first token of response + token_text.erase(0, token_text.find_first_not_of(" \t\n\n")); + } + std::cout << token_text << std::flush; + } + return true; + }; + + inference.temperature = 1.0f; + inference.deterministic = true; + inference.multiturn = false; + + GenerateGemma( + model, inference, tokens, 0, pool, inner_pool, stream_token, + [](int) {return true;}, gen, 0); + + std::cout << std::endl; +} diff --git a/examples/look/run.cc b/examples/look/run.cc deleted file mode 100644 index 85b0da4..0000000 --- a/examples/look/run.cc +++ /dev/null @@ -1,28 +0,0 @@ -#include - -// copybara:import_next_line:gemma_cpp -#include "compression/compress.h" -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "gemma.h" // Gemma -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "util/app.h" -// copybara:end -// copybara:import_next_line:gemma_cpp -#include "util/args.h" // HasHelp -// copybara:end -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/highway.h" -#include "hwy/per_target.h" -#include "hwy/profiler.h" -#include "hwy/timer.h" - -int main(int argc, char** argv) { - gcpp::LoaderArgs loader(argc, argv); - gcpp::AppArgs app(argc, argv); - hwy::ThreadPool pool(app.num_threads); - gcpp::Gemma model(loader, pool); - std::cout << "Done" << std::endl; -} diff --git a/models/.gitignore b/models/.gitignore new file mode 100644 index 0000000..e69de29 From 10f7a086aa9ff650084051bb605fc7be9c862568 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Wed, 6 Mar 2024 15:06:41 -0500 Subject: [PATCH 03/20] [WIP] decouple GemmaImpl from CLI args --- examples/hello_world/build/.gitignore | 2 + examples/hello_world/run.cc | 9 ++-- gemma.cc | 77 ++++++++++++++++----------- gemma.h | 24 ++++++--- run.cc | 8 +-- util/app.h | 6 +-- 6 files changed, 78 insertions(+), 48 deletions(-) diff --git a/examples/hello_world/build/.gitignore b/examples/hello_world/build/.gitignore index e69de29..d6b7ef3 100644 --- a/examples/hello_world/build/.gitignore +++ b/examples/hello_world/build/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index a017e22..1a4beed 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -12,6 +12,9 @@ // copybara:import_next_line:gemma_cpp #include "util/args.h" // HasHelp // copybara:end +// copybara:import_next_line:gemma_cpp +#include "configs.h" +// copybara:end #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/highway.h" @@ -35,17 +38,13 @@ int main(int argc, char** argv) { hwy::ThreadPool pool(app.num_threads); hwy::ThreadPool inner_pool(0); gcpp::Gemma model(loader, pool); - - std::vector tokens = tokenize("Hello, how are you?", 
-
   std::mt19937 gen;
   std::random_device rd;
   gen.seed(rd());
+  std::vector<int> tokens = tokenize("Hello, how are you?", model.Tokenizer());
   size_t ntokens = tokens.size();
-
   size_t pos = 0;
-
   auto stream_token = [&pos, &gen, &ntokens, tokenizer = &model.Tokenizer()](int token, float) {
     ++pos;
     if (pos < ntokens) {
diff --git a/gemma.cc b/gemma.cc
index 9f1e4a0..add3721 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -19,18 +19,18 @@
 // which we pass the filename via macro 'argument'.
 #undef HWY_TARGET_INCLUDE
 #define HWY_TARGET_INCLUDE "gemma.cc"  // NOLINT
-#include "hwy/foreach_target.h"        // IWYU pragma: keep
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
 // Must come after foreach_target.h to avoid redefinition errors.
 // copybara:import_next_line:gemma_cpp
 #include "compression/compress-inl.h"
 // copybara:import_next_line:gemma_cpp
 #include "ops.h"
 // copybara:import_next_line:gemma_cpp
-#include "util/args.h"  // Path
 #include "hwy/contrib/matvec/matvec-inl.h"
 #include "hwy/highway.h"
 #include "hwy/profiler.h"
 #include "hwy/timer.h"
+#include "util/args.h"  // Path
 
 // Non-SIMD includes and types. Note that HWY_ONCE is only true on the last
 // compile pass, whereas we want this defined in the first.
@@ -231,9 +231,8 @@ struct Activations {
 struct GemmaInterface {
   virtual ~GemmaInterface() = default;
 
-  virtual const sentencepiece::SentencePieceProcessor& Tokenizer() const = 0;
+  virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
 
-  // TODO: group pool/callbacks into struct
   virtual void Generate(const InferenceArgs& args,
                         const std::vector<int>& prompt, size_t start_pos,
                         hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
@@ -244,7 +243,10 @@ struct GemmaInterface {
 
 template <class Config>
 struct GemmaImpl : public GemmaInterface {
-  GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool);
+  GemmaImpl(  // const LoaderArgs& args,
+      std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
+      hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
+      hwy::ThreadPool& pool);
 
   ~GemmaImpl() {
     using CWeights = CompressedWeights<Config>;
@@ -252,8 +254,8 @@ struct GemmaImpl : public GemmaInterface {
     c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
   }
 
-  const sentencepiece::SentencePieceProcessor& Tokenizer() const {
-    return tokenizer;
+  const sentencepiece::SentencePieceProcessor* Tokenizer() const {
+    return tokenizer.get();
   }
 
   void Generate(const InferenceArgs& args, const std::vector<int>& prompt,
@@ -261,9 +263,8 @@ struct GemmaImpl : public GemmaInterface {
                 hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                 const AcceptFunc& accept_token, std::mt19937&, int verbosity);
 
-  sentencepiece::SentencePieceProcessor tokenizer;
+  std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
 
-  // CompressedWeights<Config>
   hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
   hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
   hwy::AlignedUniquePtr<Activations<Config, 1>> state;
@@ -495,7 +496,8 @@ void Transformer(int token, size_t pos,
 }
 
 template <class TConfig>
-void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
+void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
+                  size_t max_generated_tokens, float temperature,
                   const std::vector<int>& prompt, size_t pos,
                   hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                   const StreamFunc& stream_token,
@@ -549,7 +551,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
   // should be available as observable state for frontend code to handle I/O.
     double prefill_end = hwy::platform::Now();
     const double prefill_tok_sec = pos_offset / (prefill_end - prefill_start);
-    std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]\n";
+    std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
  }
 
   double gen_start = hwy::platform::Now();
@@ -558,10 +560,10 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
 
   if (verbosity >= 2) {
     // Provide usage warnings if max_new_tokens is out of range.
-    if (args.max_generated_tokens > args.max_tokens) {
+    if (max_generated_tokens > max_tokens) {
       std::cout << "Warning: max_new_tokens should be <= max_tokens"
                 << std::endl;
-    } else if ((prompt.size() + args.max_generated_tokens) > args.max_tokens) {
+    } else if ((prompt.size() + max_generated_tokens) > max_tokens) {
       std::cout << "Warning: Prompt size + max_new_tokens exceeds max_tokens."
                 << std::endl;
     }
@@ -570,7 +572,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
   auto pos_gen_start = pos_offset;
   token = prompt.at(pos_offset);
   size_t generate_pos = 0;
-  for (; pos < args.max_tokens && generate_pos < args.max_generated_tokens;
+  for (; pos < max_tokens && generate_pos < max_generated_tokens;
        ++pos, ++pos_offset, ++generate_pos) {
     Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
     float* final_activation = activations.x.data();
@@ -583,7 +585,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
       // Barrier: must have all logits so we can subtract max.
       Softmax(activations.logits.data(), kVocabSize);
       token = SampleTopK(activations.logits.data(), kVocabSize, gen,
-                         args.temperature, accept_token);
+                         temperature, accept_token);
     }
     if (!stream_token(token, activations.logits[token])) {
      token = EOS_ID;
@@ -593,7 +595,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
       double gen_end = hwy::platform::Now();
       const double gen_tok_sec =
           (pos_offset - pos_gen_start) / (gen_end - gen_start);
-      std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
+      std::cout << "[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
       }
       break;
     }
@@ -605,8 +607,9 @@ void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, const InferenceArgs& args,
                 const std::vector<int>& prompt, size_t start_pos,
                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                 const StreamFunc& stream_token, const AcceptFunc& accept_token,
                 std::mt19937& gen, int verbosity) {
-  GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity);
+  GenerateImpl(gemma, args.max_tokens, args.max_generated_tokens,
+               args.temperature, prompt, start_pos, pool, inner_pool,
+               stream_token, accept_token, gen, verbosity);
 }
 
 void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, const InferenceArgs& args,
                 const std::vector<int>& prompt, size_t start_pos,
                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                 const StreamFunc& stream_token, const AcceptFunc& accept_token,
                 std::mt19937& gen, int verbosity) {
-  GenerateImpl(gemma, args, prompt, start_pos, pool, inner_pool, stream_token,
-               accept_token, gen, verbosity);
+  GenerateImpl(gemma, args.max_tokens, args.max_generated_tokens,
+               args.temperature, prompt, start_pos, pool, inner_pool,
+               stream_token, accept_token, gen, verbosity);
 }
 
 // Calls func(name, float*, CompressedArray&) for each tensor. float* is null
@@ -729,17 +733,22 @@ KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
 }
 
 template <class Config>
-GemmaImpl<Config>::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& pool)
-    : compressed_weights(
-          HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
+GemmaImpl<Config>::GemmaImpl(
+    std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
+    hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
+    hwy::ThreadPool& pool)
+    // GemmaImpl<Config>::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool&
+    // pool)
+    : compressed_weights(std::move(compressed_weights)),
+      // HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
       prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
       state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
       kv_cache(
           CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
-                        Config::kSeqLen)) {
-  PROFILER_ZONE("Startup.tokenizer");
-
-  HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
+                        Config::kSeqLen)),
+      tokenizer(std::move(tokenizer)) {
+  // PROFILER_ZONE("Startup.tokenizer");
+  // HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
 }
 
 template <>
@@ -770,12 +779,20 @@ void GemmaImpl<ConfigGemma7B>::Generate(const InferenceArgs& args,
 Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
   const Model model_type = args.ModelType();
   model_training = args.ModelTraining();
+  PROFILER_ZONE("Startup.tokenizer");
+  std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer =
+      std::make_unique<sentencepiece::SentencePieceProcessor>();
+  HWY_ASSERT(tokenizer->Load(args.tokenizer.path).ok());
+  auto compressed_weights =
+      HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool);
   switch (model_type) {
     case Model::GEMMA_2B:
-      impl_.reset(new GemmaImpl<ConfigGemma2B>(args, pool));
+      impl_.reset(
+          new GemmaImpl<ConfigGemma2B>(tokenizer, compressed_weights, pool));
       break;
     case Model::GEMMA_7B:
-      impl_.reset(new GemmaImpl<ConfigGemma7B>(args, pool));
+      impl_.reset(
+          new GemmaImpl<ConfigGemma7B>(tokenizer, compressed_weights, pool));
       break;
     default:
       HWY_ABORT("Model type %d unknown.", static_cast<int>(model_type));
  }
 }
 
 Gemma::~Gemma() = default;  // after GemmaInterface is defined
-const sentencepiece::SentencePieceProcessor& Gemma::Tokenizer() const {
+const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
   return impl_->Tokenizer();
 }
 
diff --git a/gemma.h b/gemma.h
index 7195bc9..3de9f0e 100644
--- a/gemma.h
+++ b/gemma.h
@@ -64,6 +64,15 @@ struct KVCache {
 enum class Model { GEMMA_2B, GEMMA_7B };
 enum class ModelTraining { GEMMA_IT, GEMMA_PT };
 
+// TODO: incorporate
+struct InferenceParams {
+  Model model;
+  ModelTraining model_training;
+  size_t max_generated_tokens;
+  size_t max_tokens;
+  float temperature;
+};
+
 struct LoaderArgs : public ArgsBase<LoaderArgs> {
   LoaderArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
 
@@ -129,9 +138,9 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
            "file if "
            "the compressed weights file does not exist.\n    Required argument.");
     visitor(model_type, "model", std::string(),
-            "Model type\n    2b-it (2B parameters, instruction-tuned)\n    "
-            "2b-pt (2B parameters, pretrained)\n    7b-it (7B parameters "
-            "instruction-tuned)\n    7b-pt (7B parameters, pretrained)\n"
+            "Model type\n    2b-it = 2B parameters, instruction-tuned\n    "
+            "2b-pt = 2B parameters, pretrained\n    7b-it = 7B parameters "
+            "instruction-tuned\n    7b-pt = 7B parameters, pretrained\n"
             "    Required argument.");
     visitor(model, "weights", Path(),
             "Path name of model weights (.sbs) file. Only required if "
@@ -147,7 +156,10 @@ struct Gemma {
   Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
   ~Gemma();  // must be defined after GemmaInterface's dtor is defined.
-  const sentencepiece::SentencePieceProcessor& Tokenizer() const;
+  // TODO: cleanup
+  // const sentencepiece::SentencePieceProcessor& Tokenizer() const;
+  // const std::unique_ptr<sentencepiece::SentencePieceProcessor> Tokenizer() const;
+  const sentencepiece::SentencePieceProcessor* Tokenizer() const;
 
   std::unique_ptr<GemmaInterface> impl_;
   gcpp::ModelTraining model_training;
@@ -192,8 +204,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
     visitor(deterministic, "deterministic", false,
             "Make top-k sampling deterministic", 2);
     visitor(multiturn, "multiturn", false,
-            "Multiturn mode (if 0, this clears the KV cache after every "
-            "interaction without quitting)\n    Default : 0 (conversation "
+            "Multiturn mode\n    0 = clear KV cache after every "
+            "interaction\n    1 = continue KV cache after every interaction\n    Default : 0 (conversation "
            "resets every turn)");
   }
 };
diff --git a/run.cc b/run.cc
index 507979d..50b9a24 100644
--- a/run.cc
+++ b/run.cc
@@ -115,7 +115,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
 
   // callback function invoked for each generated token.
   auto stream_token = [&abs_pos, &current_pos, &args, &gen, &prompt_size,
-                       tokenizer = &model.Tokenizer(),
+                       tokenizer = model.Tokenizer(),
                        verbosity](int token, float) {
     ++abs_pos;
     ++current_pos;
@@ -129,7 +129,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
         }
       }
       if (verbosity >= 2) {
-        std::cout << "\n[ End ]" << std::endl;
+        std::cout << "\n[ End ]\n";
       }
     } else {
       std::string token_text;
@@ -142,7 +142,6 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
           std::cout << std::endl << std::endl;
         }
       }
-      // TODO(austinvhuang): is explicit space necessary?
       std::cout << token_text << std::flush;
     }
     return true;
@@ -191,7 +190,8 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
       }
     }
 
-    HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
+    // HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok());
+    HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok());
 
     // For both pre-trained and instruction-tuned models: prepend "<bos>" token
     // if needed.
diff --git a/util/app.h b/util/app.h
index 7f926a5..754b2fb 100644
--- a/util/app.h
+++ b/util/app.h
@@ -79,9 +79,9 @@ class AppArgs : public ArgsBase<AppArgs> {
   template <class Visitor>
   void ForEach(const Visitor& visitor) {
     visitor(verbosity, "verbosity", 1,
-            "Show verbose developer information\n    0 = only print generation "
-            "output\n    1 = standard user-facing terminal ui\n    2 = show "
-            "developer/debug info).\n    Default = 1.",
+            "Show verbose developer information\n  0 = only print generation "
+            "output\n  1 = standard user-facing terminal ui\n  2 = show "
+            "developer/debug info).\n  Default = 1.",
             2);
     visitor(num_threads, "num_threads",
             kDefaultNumThreads,  // see ChooseNumThreads

From 7042316013d7e2ad06532f420b551802ec114f89 Mon Sep 17 00:00:00 2001
From: austinvhuang
Date: Wed, 6 Mar 2024 15:34:11 -0500
Subject: [PATCH 04/20] [WIP] quality tweaks - for constants, defer float cast
 and use double for intermediate computations, add `model` to EOT token

---
 gemma.cc | 6 +++---
 run.cc   | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/gemma.cc b/gemma.cc
index 195b177..dcd87ed 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -295,7 +295,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
   static constexpr size_t kModelDim =
       gcpp::Activations<TConfig, kBatchSize>::kModelDim;
   static constexpr size_t kHeads = TConfig::kHeads;
-  const float kQueryScale = 1.0 / sqrtf(static_cast<float>(kQKVDim));
+  static const float kQueryScale = static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
 
   pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
     // linear projections to QKV
@@ -418,7 +418,7 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
                           hwy::ThreadPool& inner_pool) {
   PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
   static constexpr size_t kModelDim = TConfig::kModelDim;
-  static const float kEmbScaling = sqrtf(static_cast<float>(kModelDim));
+  static const float kEmbScaling = static_cast<float>(sqrt(static_cast<double>(kModelDim)));
 
   pool.Run(
       0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
@@ -473,7 +473,7 @@ void Transformer(int token, size_t pos,
   static constexpr size_t kLayers = TConfig::kLayers;
   static constexpr size_t kModelDim = TConfig::kModelDim;
 
-  static const float kEmbScaling = sqrtf(static_cast<float>(kModelDim));
+  static const float kEmbScaling = static_cast<float>(sqrt(static_cast<double>(kModelDim)));
 
   Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
              activations.x.data(), kModelDim);
diff --git a/run.cc b/run.cc
index 50b9a24..71481f9 100644
--- a/run.cc
+++ b/run.cc
@@ -186,7 +186,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
       if (abs_pos > 0) {
         // Prepend "<end_of_turn>" token if this is a multi-turn dialogue
         // continuation.
-        prompt_string = "<end_of_turn>\n" + prompt_string;
+        prompt_string = "<end_of_turn><start_of_turn>model\n" + prompt_string;
       }
     }

From 0f6a4b49d5af7e85d4e2ec5391a68aae343cc608 Mon Sep 17 00:00:00 2001
From: austinvhuang
Date: Wed, 6 Mar 2024 22:22:59 -0500
Subject: [PATCH 05/20] [WIP] update GemmaInterface, Gemma, and Generate input
 parameter specs to remove InferenceArgs.
 TODO: update hello_world example after git commit hash is available for
 fetching

---
 examples/hello_world/run.cc |  2 +-
 gemma.cc                    | 82 +++++++++++++++++++++++++------------
 gemma.h                     | 14 +++----
 run.cc                      |  5 ++-
 4 files changed, 53 insertions(+), 50 deletions(-)

diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index 1a4beed..67953dc 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -66,7 +66,7 @@ int main(int argc, char** argv) {
   inference.multiturn = false;
 
   GenerateGemma(
-      model, inference, tokens, 0, pool, inner_pool, stream_token,
+      model, /*max_tokens=*/2048, /*max_generated_tokens=*/1024, /*temperature=*/1.0, tokens, 0, pool, inner_pool, stream_token,
       [](int) {return true;}, gen, 0);
 
   std::cout << std::endl;
 }
diff --git a/gemma.cc b/gemma.cc
index dcd87ed..d49d0df 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -233,9 +233,10 @@ struct GemmaInterface {
 
   virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
 
-  virtual void Generate(const InferenceArgs& args,
-                        const std::vector<int>& prompt, size_t start_pos,
-                        hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
+  virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
+                        float temperature, const std::vector<int>& prompt,
+                        size_t start_pos, hwy::ThreadPool& pool,
+                        hwy::ThreadPool& inner_pool,
                         const StreamFunc& stream_token,
                         const AcceptFunc& accept_token, std::mt19937& gen,
                         int verbosity) = 0;
 };
@@ -258,7 +259,8 @@ struct GemmaImpl : public GemmaInterface {
     return tokenizer.get();
   }
 
-  void Generate(const InferenceArgs& args, const std::vector<int>& prompt,
+  void Generate(size_t max_tokens, size_t max_generated_tokens,
+                float temperature, const std::vector<int>& prompt,
                 size_t start_pos, hwy::ThreadPool& pool,
                 hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                 const AcceptFunc& accept_token, std::mt19937&, int verbosity);
@@ -295,7 +297,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
   static constexpr size_t kModelDim =
       gcpp::Activations<TConfig, kBatchSize>::kModelDim;
   static constexpr size_t kHeads = TConfig::kHeads;
-  static const float kQueryScale = static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
+  static const float kQueryScale =
+      static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
 
   pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
     // linear projections to QKV
@@ -418,7 +421,8 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
                           hwy::ThreadPool& inner_pool) {
   PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
   static constexpr size_t kModelDim = TConfig::kModelDim;
-  static const float kEmbScaling = static_cast<float>(sqrt(static_cast<double>(kModelDim)));
+  static const float kEmbScaling =
+      static_cast<float>(sqrt(static_cast<double>(kModelDim)));
 
   pool.Run(
       0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
@@ -473,7 +477,8 @@ void Transformer(int token, size_t pos,
   static constexpr size_t kLayers = TConfig::kLayers;
   static constexpr size_t kModelDim = TConfig::kModelDim;
 
-  static const float kEmbScaling = static_cast<float>(sqrt(static_cast<double>(kModelDim)));
+  static const float kEmbScaling =
+      static_cast<float>(sqrt(static_cast<double>(kModelDim)));
 
   Decompress(c_weights.c_embedder_input_embedding, token * kModelDim,
              activations.x.data(), kModelDim);
@@ -604,24 +609,26 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
   }
 }
 
-void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, const InferenceArgs& args,
+void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
+                size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                 const StreamFunc& stream_token, const AcceptFunc& accept_token,
                 std::mt19937& gen, int verbosity) {
-  GenerateImpl(gemma, args.max_tokens, args.max_generated_tokens,
-               args.temperature, prompt, start_pos, pool, inner_pool,
-               stream_token, accept_token, gen, verbosity);
+  GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
+               start_pos, pool, inner_pool, stream_token, accept_token, gen,
+               verbosity);
 }
 
-void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, const InferenceArgs& args,
+void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
+                size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
                 hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                 const StreamFunc& stream_token, const AcceptFunc& accept_token,
                 std::mt19937& gen, int verbosity) {
-  GenerateImpl(gemma, args.max_tokens, args.max_generated_tokens,
-               args.temperature, prompt, start_pos, pool, inner_pool,
-               stream_token, accept_token, gen, verbosity);
+  GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
+               start_pos, pool, inner_pool, stream_token, accept_token, gen,
+               verbosity);
 }
 
 // Calls func(name, float*, CompressedArray&) for each tensor. float* is null
@@ -755,28 +762,24 @@ GemmaImpl<Config>::GemmaImpl(
 }
 
 template <>
-void GemmaImpl<ConfigGemma2B>::Generate(const InferenceArgs& args,
-                                        const std::vector<int>& prompt,
-                                        size_t start_pos, hwy::ThreadPool& pool,
-                                        hwy::ThreadPool& inner_pool,
-                                        const StreamFunc& stream_token,
-                                        const AcceptFunc& accept_token,
-                                        std::mt19937& gen, int verbosity) {
+void GemmaImpl<ConfigGemma2B>::Generate(
+    size_t max_tokens, size_t max_generated_tokens, float temperature,
+    const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
+    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
   HWY_DYNAMIC_DISPATCH(Generate2B)
-  (*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token,
-   gen, verbosity);
+  (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
+   pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
 
 template <>
-void GemmaImpl<ConfigGemma7B>::Generate(const InferenceArgs& args,
-                                        const std::vector<int>& prompt,
-                                        size_t start_pos, hwy::ThreadPool& pool,
-                                        hwy::ThreadPool& inner_pool,
-                                        const StreamFunc& stream_token,
-                                        const AcceptFunc& accept_token,
-                                        std::mt19937& gen, int verbosity) {
+void GemmaImpl<ConfigGemma7B>::Generate(
+    size_t max_tokens, size_t max_generated_tokens, float temperature,
+    const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
+    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
   HWY_DYNAMIC_DISPATCH(Generate7B)
-  (*this, args, prompt, start_pos, pool, inner_pool, stream_token, accept_token,
-   gen, verbosity);
+  (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
+   pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
 
 Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
@@ -807,15 +810,16 @@ const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
   return impl_->Tokenizer();
 }
 
-void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
-                   const std::vector<int>& prompt, size_t start_pos,
-                   hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                   const StreamFunc& stream_token,
+void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
+                   float temperature, const std::vector<int>& prompt,
+                   size_t start_pos, hwy::ThreadPool& pool,
+                   hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity) {
   pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
-  gemma.impl_->Generate(args, prompt, start_pos, pool, inner_pool, stream_token,
-                        accept_token, gen, verbosity);
+  gemma.impl_->Generate(max_tokens, max_generated_tokens,
+                        temperature, prompt, start_pos, pool, inner_pool,
+                        stream_token, accept_token, gen, verbosity);
   pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
diff --git a/gemma.h b/gemma.h
index 3de9f0e..5a2f2b0 100644
--- a/gemma.h
+++ b/gemma.h
@@ -156,9 +156,6 @@ struct Gemma {
   Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
   ~Gemma();  // must be defined after GemmaInterface's dtor is defined.
 
-  // TODO: cleanup
-  // const sentencepiece::SentencePieceProcessor& Tokenizer() const;
-  // const std::unique_ptr<sentencepiece::SentencePieceProcessor> Tokenizer() const;
   const sentencepiece::SentencePieceProcessor* Tokenizer() const;
 
   std::unique_ptr<GemmaInterface> impl_;
@@ -205,15 +202,16 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
             "Make top-k sampling deterministic", 2);
     visitor(multiturn, "multiturn", false,
             "Multiturn mode\n    0 = clear KV cache after every "
-            "interaction\n    1 = continue KV cache after every interaction\n    Default : 0 (conversation "
+            "interaction\n    1 = continue KV cache after every interaction\n "
+            "   Default : 0 (conversation "
             "resets every turn)");
   }
 };
 
-void GenerateGemma(Gemma& gemma, const InferenceArgs& args,
-                   const std::vector<int>& prompt, size_t start_pos,
-                   hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                   const StreamFunc& stream_token,
+void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
+                   float temperature, const std::vector<int>& prompt,
+                   size_t start_pos, hwy::ThreadPool& pool,
+                   hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& g,
                    int verbosity);
 
diff --git a/run.cc b/run.cc
index 71481f9..eac5f9e 100644
--- a/run.cc
+++ b/run.cc
@@ -204,8 +204,9 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
     std::cerr << std::endl << "[ Reading prompt ] " << std::flush;
     const double time_start = hwy::platform::Now();
-    GenerateGemma(model, args, prompt, abs_pos, pool, inner_pool, stream_token,
-                  accept_token, gen, verbosity);
+    GenerateGemma(model, args.max_tokens, args.max_generated_tokens,
+                  args.temperature, prompt, abs_pos, pool, inner_pool,
+                  stream_token, accept_token, gen, verbosity);
     const double time_end = hwy::platform::Now();
     const double tok_sec = current_pos / (time_end - time_start);
     if (verbosity >= 2) {

From e781007836ec034236e90cc4d313d0a8c481bce6 Mon Sep 17 00:00:00 2001
From: austinvhuang
Date: Wed, 6 Mar 2024 23:21:13 -0500
Subject: [PATCH 06/20] [WIP] Remove InferenceArgs from hello_world example,
 fix ordering of LoaderArgs validation, revert ReplGemma EOT token behavior

---
 examples/hello_world/CMakeLists.txt |  2 +-
 examples/hello_world/run.cc         | 37 +++++++++++++----------------
 gemma.cc                            |  8 ++++---
 gemma.h                             |  8 +++----
 run.cc                              |  2 +-
 5 files changed, 28 insertions(+), 29 deletions(-)

diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt
index 69e4e98..63c8d7b 100644
--- a/examples/hello_world/CMakeLists.txt
+++ b/examples/hello_world/CMakeLists.txt
@@ -27,7 +27,7 @@ FetchContent_MakeAvailable(highway)
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
 FetchContent_MakeAvailable(sentencepiece)
 
-FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 808dbdc42b216c3ac1f1c40dfa638bcff24bbd2b)
+FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 7042316013d7e2ad06532f420b551802ec114f89)
 FetchContent_MakeAvailable(gemma)
 
 if(NOT CMAKE_BUILD_TYPE)
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index 67953dc..227c91c 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -7,9 +7,6 @@
 #include "gemma.h"  // Gemma
 // copybara:end
 // copybara:import_next_line:gemma_cpp
-#include "util/app.h"
-// copybara:end
-// copybara:import_next_line:gemma_cpp
 #include "util/args.h"  // HasHelp
 // copybara:end
 // copybara:import_next_line:gemma_cpp
@@ -22,30 +19,35 @@
 #include "hwy/profiler.h"
 #include "hwy/timer.h"
 
-std::vector<int> tokenize(std::string prompt_string, const sentencepiece::SentencePieceProcessor& tokenizer) {
+std::vector<int> tokenize(
+    std::string prompt_string,
+    const sentencepiece::SentencePieceProcessor* tokenizer) {
   prompt_string = "<start_of_turn>user\n" + prompt_string +
                   "<end_of_turn>\n<start_of_turn>model\n";
   std::vector<int> tokens;
-  HWY_ASSERT(tokenizer.Encode(prompt_string, &tokens).ok());
-  tokens.insert(tokens.begin(), 2);  // BOS token
+  HWY_ASSERT(tokenizer->Encode(prompt_string, &tokens).ok());
+  tokens.insert(tokens.begin(), 2);  // BOS token
   return tokens;
 }
 
 int main(int argc, char** argv) {
-  gcpp::InferenceArgs inference(argc, argv);
   gcpp::LoaderArgs loader(argc, argv);
-  gcpp::AppArgs app(argc, argv);
-  hwy::ThreadPool pool(app.num_threads);
+  // A rough heuristic for a reasonable number of threads given hardware
+  // concurrency estimate
+  size_t num_threads = static_cast<size_t>(std::clamp(
+      static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
+  hwy::ThreadPool pool(num_threads);
   hwy::ThreadPool inner_pool(0);
   gcpp::Gemma model(loader, pool);
   std::mt19937 gen;
   std::random_device rd;
   gen.seed(rd());
-  std::vector<int> tokens = tokenize("Hello, how are you?", model.Tokenizer());
+  std::vector<int> tokens =
+      tokenize("Write a greeting to the world.", model.Tokenizer());
   size_t ntokens = tokens.size();
   size_t pos = 0;
-  auto stream_token = [&pos, &gen, &ntokens, tokenizer = &model.Tokenizer()](int token, float) {
+  auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()](
+                          int token, float) {
     ++pos;
     if (pos < ntokens) {
       // print feedback
@@ -60,14 +62,9 @@ int main(int argc, char** argv) {
     }
     return true;
   };
-
-  inference.temperature = 1.0f;
-  inference.deterministic = true;
-  inference.multiturn = false;
-
   GenerateGemma(
-      model, /*max_tokens=*/2048, /*max_generated_tokens=*/1024, /*temperature=*/1.0, tokens, 0, pool, inner_pool, stream_token,
-      [](int) {return true;}, gen, 0);
-
+      model, /*max_tokens=*/2048, /*max_generated_tokens=*/1024,
+      /*temperature=*/1.0, tokens, 0, pool, inner_pool, stream_token,
+      [](int) { return true; }, gen, 0);
   std::cout << std::endl;
 }
diff --git a/gemma.cc b/gemma.cc
index d49d0df..bbd86c3 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -782,6 +782,8 @@ void GemmaImpl<ConfigGemma7B>::Generate(
   pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
 
+// TODO: Make Gemma type independent of LoaderArgs, create a factory function
+// that takes LoaderArgs and creates a Gemma instance.
 Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
   const Model model_type = args.ModelType();
   model_training = args.ModelTraining();
@@ -817,9 +819,9 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity) {
   pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
-  gemma.impl_->Generate(max_tokens, max_generated_tokens,
-                        temperature, prompt, start_pos, pool, inner_pool,
-                        stream_token, accept_token, gen, verbosity);
+  gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
+                        start_pos, pool, inner_pool, stream_token, accept_token,
+                        gen, verbosity);
   pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
diff --git a/gemma.h b/gemma.h
index 5a2f2b0..3528b50 100644
--- a/gemma.h
+++ b/gemma.h
@@ -104,6 +104,10 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   // Returns error string or nullptr if OK.
   const char* Validate() const {
     const std::string model_type_lc = ToLower(model_type);
+    if (model_type.empty()) {
+      return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
+             "2b-it, or 7b-it.";
+    }
     if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
         model_type_lc != "2b-it" && model_type_lc != "7b-it") {
       return "Model type must be 2b-pt, 7b-pt, 2b-it, or "
@@ -112,10 +116,6 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
     if (tokenizer.path.empty()) {
       return "Missing --tokenizer flag, a file for the tokenizer is required.";
     }
-    if (model_type.empty()) {
-      return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
-             "2b-it, or 7b-it.";
-    }
     if (cache.path.empty()) {
       return "Missing --compressed_weights flag, a file for the compressed "
              "model.";
diff --git a/run.cc b/run.cc
index eac5f9e..4c4f132 100644
--- a/run.cc
+++ b/run.cc
@@ -186,7 +186,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
       if (abs_pos > 0) {
         // Prepend "<end_of_turn>" token if this is a multi-turn dialogue
         // continuation.
-        prompt_string = "<end_of_turn><start_of_turn>model\n" + prompt_string;
+        prompt_string = "<end_of_turn>\n" + prompt_string;
       }
     }

From 49e654258dc556f694272a3b7dfe37bc449211fd Mon Sep 17 00:00:00 2001
From: austinvhuang
Date: Thu, 7 Mar 2024 01:04:25 -0500
Subject: [PATCH 07/20] [WIP] clean up hello_world #includes and CMakeLists.txt

---
 examples/hello_world/CMakeLists.txt |  9 +--------
 examples/hello_world/run.cc         | 15 ++-------------
 2 files changed, 3 insertions(+), 21 deletions(-)

diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt
index 63c8d7b..088af84 100644
--- a/examples/hello_world/CMakeLists.txt
+++ b/examples/hello_world/CMakeLists.txt
@@ -13,21 +13,16 @@
 # limitations under the License.
 
 cmake_minimum_required(VERSION 3.11)
-
 project(hello_world)
-
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
 include(FetchContent)
-
 FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
 FetchContent_MakeAvailable(highway)
-
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
 FetchContent_MakeAvailable(sentencepiece)
-
-FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 7042316013d7e2ad06532f420b551802ec114f89)
+FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG e781007836ec034236e90cc4d313d0a8c481bce6)
 FetchContent_MakeAvailable(gemma)
 
 if(NOT CMAKE_BUILD_TYPE)
@@ -35,10 +30,8 @@ if(NOT CMAKE_BUILD_TYPE)
 endif()
 
 add_executable(hello_world run.cc)
-target_sources(hello_world PRIVATE ${SOURCES})
 set_property(TARGET hello_world PROPERTY CXX_STANDARD 17)
 target_link_libraries(hello_world hwy hwy_contrib sentencepiece libgemma)
-target_include_directories(hello_world PRIVATE ./)
 FetchContent_GetProperties(sentencepiece)
 target_include_directories(hello_world PRIVATE ${sentencepiece_SOURCE_DIR})
 target_compile_definitions(hello_world PRIVATE $<$<PLATFORM_ID:Windows>:_CRT_SECURE_NO_WARNINGS NOMINMAX>)
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index 227c91c..a972154 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -1,23 +1,12 @@
 #include <iostream>
 
 // copybara:import_next_line:gemma_cpp
-#include "compression/compress.h"
+#include "gemma.h"
 // copybara:end
 // copybara:import_next_line:gemma_cpp
-#include "gemma.h"  // Gemma
+#include "util/args.h"
 // copybara:end
-// copybara:import_next_line:gemma_cpp
-#include "util/args.h"  // HasHelp
-// copybara:end
-// copybara:import_next_line:gemma_cpp
-#include "configs.h"
-// copybara:end
-#include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
-#include "hwy/highway.h"
-#include "hwy/per_target.h"
-#include "hwy/profiler.h"
-#include "hwy/timer.h"
 
 std::vector<int> tokenize(
     std::string prompt_string,

From 6c0388e0495083086f724be98965a016807ae536 Mon Sep 17 00:00:00 2001
From: austinvhuang
Date: Thu, 7 Mar 2024 01:14:07 -0500
Subject: [PATCH 08/20] [WIP] refine Runtime struct definition

---
 gemma.h | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/gemma.h b/gemma.h
index 3528b50..f6361e1 100644
--- a/gemma.h
+++ b/gemma.h
@@ -64,13 +64,14 @@ struct KVCache {
 enum class Model { GEMMA_2B, GEMMA_7B };
 enum class ModelTraining { GEMMA_IT, GEMMA_PT };
 
-// TODO: incorporate
-struct InferenceParams {
-  Model model;
+// TODO: Incorporate this
+struct Runtime {
+  Model model_type;
   ModelTraining model_training;
-  size_t max_generated_tokens;
   size_t max_tokens;
+  size_t max_generated_tokens;
   float temperature;
+  std::mt19937 gen;
 };
 
 struct LoaderArgs : public ArgsBase<LoaderArgs> {
@@ -213,7 +214,7 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens,
                    size_t max_generated_tokens, float temperature,
                    const std::vector<int>& prompt, size_t start_pos,
                    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                    const StreamFunc& stream_token,
-                   const AcceptFunc& accept_token, std::mt19937& g,
+                   const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity);
 
 constexpr int EOS_ID = 1;

From b841612e8cb56fcbb4cff8a73e196460520ede31 Mon Sep 17 00:00:00 2001
From: RangerUFO
Date: Tue, 5 Mar 2024 17:50:24 +0800
Subject: [PATCH 09/20] Separate KV cache from GemmaImpl

---
 gemma.cc | 89 +++++++++++++++++++++++++++++++-------------------------
 gemma.h  |  3 +-
 run.cc   | 13 +++++----
 3 files changed, 59 insertions(+), 46 deletions(-)

diff --git a/gemma.cc b/gemma.cc
index bbd86c3..31b3c38 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -231,12 +231,13 @@ struct Activations {
 struct GemmaInterface {
   virtual ~GemmaInterface() = default;
 
+  virtual KVCache CreateKVCache() const = 0;
   virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0;
 
   virtual void Generate(size_t max_tokens, size_t max_generated_tokens,
                         float temperature, const std::vector<int>& prompt,
-                        size_t start_pos, hwy::ThreadPool& pool,
-                        hwy::ThreadPool& inner_pool,
+                        size_t start_pos, KVCache& kv_cache,
+                        hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                         const StreamFunc& stream_token,
                         const AcceptFunc& accept_token, std::mt19937& gen,
                         int verbosity) = 0;
 };
@@ -255,22 +256,24 @@ struct GemmaImpl : public GemmaInterface {
     c_weights->c_layer_ptrs.~CompressedLayerPointers<Config>();
   }
 
-  const sentencepiece::SentencePieceProcessor* Tokenizer() const {
+  KVCache CreateKVCache() const override;
+
+  const sentencepiece::SentencePieceProcessor* Tokenizer() const override {
     return tokenizer.get();
   }
 
   void Generate(size_t max_tokens, size_t max_generated_tokens,
                 float temperature, const std::vector<int>& prompt,
-                size_t start_pos, hwy::ThreadPool& pool,
+                size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
                 hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-                const AcceptFunc& accept_token, std::mt19937&, int verbosity);
+                const AcceptFunc& accept_token, std::mt19937&,
+                int verbosity) override;
 
   std::unique_ptr<sentencepiece::SentencePieceProcessor> tokenizer;
 
   hwy::AlignedFreeUniquePtr<uint8_t[]> compressed_weights;
   hwy::AlignedUniquePtr<Activations<Config, kPrefillBatchSize>> prefill;
   hwy::AlignedUniquePtr<Activations<Config, 1>> state;
-  KVCache kv_cache;
 };
 
 }  // namespace gcpp
@@ -503,7 +506,7 @@ void Transformer(int token, size_t pos,
 template <class TConfig>
 void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
                   size_t max_generated_tokens, float temperature,
-                  const std::vector<int>& prompt, size_t pos,
+                  const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
                   hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
                   const StreamFunc& stream_token,
                   const AcceptFunc& accept_token, std::mt19937& gen,
@@ -517,7 +520,6 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
   const CompressedWeights<TConfig>& c_weights =
       *reinterpret_cast<CompressedWeights<TConfig>*>(
           gemma.compressed_weights.get());
-  KVCache& kv_cache = gemma.kv_cache;
   int token;
 
   // pos indexes the KV cache. In the first turn of a chat, pos = 0.
@@ -612,23 +614,25 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
 void Generate2B(GemmaImpl<ConfigGemma2B>& gemma, size_t max_tokens,
                 size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
-                hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                const StreamFunc& stream_token, const AcceptFunc& accept_token,
-                std::mt19937& gen, int verbosity) {
+                KVCache& kv_cache, hwy::ThreadPool& pool,
+                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+                const AcceptFunc& accept_token, std::mt19937& gen,
+                int verbosity) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, pool, inner_pool, stream_token, accept_token, gen,
-               verbosity);
+               start_pos, kv_cache, pool, inner_pool, stream_token,
+               accept_token, gen, verbosity);
 }
 
 void Generate7B(GemmaImpl<ConfigGemma7B>& gemma, size_t max_tokens,
                 size_t max_generated_tokens, float temperature,
                 const std::vector<int>& prompt, size_t start_pos,
-                hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
-                const StreamFunc& stream_token, const AcceptFunc& accept_token,
-                std::mt19937& gen, int verbosity) {
+                KVCache& kv_cache, hwy::ThreadPool& pool,
+                hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
+                const AcceptFunc& accept_token, std::mt19937& gen,
+                int verbosity) {
   GenerateImpl(gemma, max_tokens, max_generated_tokens, temperature, prompt,
-               start_pos, pool, inner_pool, stream_token, accept_token, gen,
-               verbosity);
+               start_pos, kv_cache, pool, inner_pool, stream_token,
+               accept_token, gen, verbosity);
 }
 
 // Calls func(name, float*, CompressedArray&) for each tensor. float* is null
@@ -735,13 +739,6 @@ HWY_EXPORT(GetCompressedWeightsT);
 HWY_EXPORT(Generate2B);
 HWY_EXPORT(Generate7B);
 
-KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
-  KVCache kv_cache = {};
-  kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
-  kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
-  return kv_cache;
-}
-
 template <class Config>
 GemmaImpl<Config>::GemmaImpl(
     std::unique_ptr<sentencepiece::SentencePieceProcessor>& tokenizer,
     hwy::AlignedFreeUniquePtr<uint8_t[]>& compressed_weights,
     hwy::ThreadPool& pool)
-    // GemmaImpl<Config>::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool&
-    // pool)
     : compressed_weights(std::move(compressed_weights)),
-      // HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)),
       prefill(hwy::MakeUniqueAligned<Activations<Config, kPrefillBatchSize>>()),
       state(hwy::MakeUniqueAligned<Activations<Config, 1>>()),
-      kv_cache(
-          CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim,
-                        Config::kSeqLen)),
       tokenizer(std::move(tokenizer)) {
   // PROFILER_ZONE("Startup.tokenizer");
   // HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok());
 }
 
+template <class Config>
+KVCache GemmaImpl<Config>::CreateKVCache() const {
+  constexpr const size_t size_cache_pos = Config::kLayers * Config::kKVHeads *
+                                          Config::kQKVDim;
+  constexpr const size_t seq_len = Config::kSeqLen;
+  KVCache kv_cache = {};
+  kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
+  kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
+  return kv_cache;
+}
+
 template <>
 void GemmaImpl<ConfigGemma2B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
-    const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
-    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
+    const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
+    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
+    const StreamFunc& stream_token, const AcceptFunc& accept_token,
+    std::mt19937& gen, int verbosity) {
   HWY_DYNAMIC_DISPATCH(Generate2B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   pool, inner_pool, stream_token, accept_token, gen, verbosity);
+   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
 
 template <>
 void GemmaImpl<ConfigGemma7B>::Generate(
     size_t max_tokens, size_t max_generated_tokens, float temperature,
-    const std::vector<int>& prompt, size_t start_pos, hwy::ThreadPool& pool,
-    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
-    const AcceptFunc& accept_token, std::mt19937& gen, int verbosity) {
+    const std::vector<int>& prompt, size_t start_pos, KVCache& kv_cache,
+    hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool,
+    const StreamFunc& stream_token, const AcceptFunc& accept_token,
+    std::mt19937& gen, int verbosity) {
   HWY_DYNAMIC_DISPATCH(Generate7B)
   (*this, max_tokens, max_generated_tokens, temperature, prompt, start_pos,
-   pool, inner_pool, stream_token, accept_token, gen, verbosity);
+   kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
 
@@ -808,6 +815,10 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
 }
 
 Gemma::~Gemma() = default;  // after GemmaInterface is defined
 
+KVCache Gemma::CreateKVCache() const {
+  return impl_->CreateKVCache();
+}
+
 const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const {
   return impl_->Tokenizer();
 }
 
 void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    float temperature, const std::vector<int>& prompt,
-                   size_t start_pos, hwy::ThreadPool& pool,
+                   size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool,
                    hwy::ThreadPool& inner_pool, const StreamFunc& stream_token,
                    const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity) {
   pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
   gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
-                        start_pos, pool, inner_pool, stream_token, accept_token,
-                        gen, verbosity);
+                        start_pos, kv_cache, pool, inner_pool, stream_token,
+                        accept_token, gen, verbosity);
   pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
diff --git a/gemma.h b/gemma.h
index f6361e1..03da2e5 100644
--- a/gemma.h
+++ b/gemma.h
@@ -157,6 +157,7 @@ struct Gemma {
   Gemma(const LoaderArgs& args, hwy::ThreadPool& pool);
   ~Gemma();  // must be defined after GemmaInterface's dtor is defined.
+ KVCache CreateKVCache() const; const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr impl_; @@ -211,7 +212,7 @@ struct InferenceArgs : public ArgsBase { void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, float temperature, const std::vector& prompt, - size_t start_pos, hwy::ThreadPool& pool, + size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, const StreamFunc& stream_token, const AcceptFunc& accept_token, std::mt19937& gen, int verbosity); diff --git a/run.cc b/run.cc index 4c4f132..ff7ed3d 100644 --- a/run.cc +++ b/run.cc @@ -96,10 +96,10 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference, std::cerr << "\n"; } -void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, - hwy::ThreadPool& inner_pool, const InferenceArgs& args, - int verbosity, const gcpp::AcceptFunc& accept_token, - std::string& eot_line) { +void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache, + hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, + const InferenceArgs& args, int verbosity, + const gcpp::AcceptFunc& accept_token, std::string& eot_line) { PROFILER_ZONE("Gen.misc"); int abs_pos = 0; // absolute token index over all turns int current_pos = 0; // token index within the current turn @@ -205,7 +205,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, const double time_start = hwy::platform::Now(); GenerateGemma(model, args.max_tokens, args.max_generated_tokens, - args.temperature, prompt, abs_pos, pool, inner_pool, + args.temperature, prompt, abs_pos, kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); const double time_end = hwy::platform::Now(); const double tok_sec = current_pos / (time_end - time_start); @@ -236,6 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { } gcpp::Gemma model(loader, pool); + auto kv_cache = model.CreateKVCache(); if (const char* error = inference.Validate()) { ShowHelp(loader, inference, app); @@ -273,7 +274,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { } ReplGemma( - model, pool, inner_pool, inference, app.verbosity, + model, kv_cache, pool, inner_pool, inference, app.verbosity, /*accept_token=*/[](int) { return true; }, app.eot_line); } From 170a9b4690482dd1229a0876348cb898b34001f1 Mon Sep 17 00:00:00 2001 From: RangerUFO Date: Thu, 7 Mar 2024 14:08:48 +0800 Subject: [PATCH 10/20] Make `CreateKVCache` a free function rather than a method --- gemma.cc | 42 ++++++++++++++++++++++++------------------ gemma.h | 4 +++- run.cc | 2 +- 3 files changed, 28 insertions(+), 20 deletions(-) diff --git a/gemma.cc b/gemma.cc index 31b3c38..ba6aafa 100644 --- a/gemma.cc +++ b/gemma.cc @@ -231,7 +231,6 @@ struct Activations { struct GemmaInterface { virtual ~GemmaInterface() = default; - virtual KVCache CreateKVCache() const = 0; virtual const sentencepiece::SentencePieceProcessor* Tokenizer() const = 0; virtual void Generate(size_t max_tokens, size_t max_generated_tokens, @@ -243,6 +242,23 @@ struct GemmaInterface { int verbosity) = 0; }; +template +KVCache CreateKVCache() { + return CreateKVCache(Config::kLayers * Config::kKVHeads * Config::kQKVDim, + Config::kSeqLen); +} + +KVCache CreateKVCache(Model type) { + switch (type) { + case Model::GEMMA_2B: + return CreateKVCache(); + case Model::GEMMA_7B: + return CreateKVCache(); + default: + HWY_ABORT("Model type %d unknown.", static_cast(type)); + } +} + template struct GemmaImpl : public GemmaInterface { GemmaImpl( // const 
LoaderArgs& args, @@ -256,8 +272,6 @@ struct GemmaImpl : public GemmaInterface { c_weights->c_layer_ptrs.~CompressedLayerPointers(); } - KVCache CreateKVCache() const override; - const sentencepiece::SentencePieceProcessor* Tokenizer() const override { return tokenizer.get(); } @@ -739,6 +753,13 @@ HWY_EXPORT(GetCompressedWeightsT); HWY_EXPORT(Generate2B); HWY_EXPORT(Generate7B); +KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) { + KVCache kv_cache = {}; + kv_cache.key_cache = hwy::AllocateAligned(seq_len * size_cache_pos); + kv_cache.value_cache = hwy::AllocateAligned(seq_len * size_cache_pos); + return kv_cache; +} + template GemmaImpl::GemmaImpl( std::unique_ptr& tokenizer, @@ -755,17 +776,6 @@ GemmaImpl::GemmaImpl( // HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok()); } -template -KVCache GemmaImpl::CreateKVCache() const { - constexpr const size_t size_cache_pos = Config::kLayers * Config::kKVHeads * - Config::kQKVDim; - constexpr const size_t seq_len = Config::kSeqLen; - KVCache kv_cache = {}; - kv_cache.key_cache = hwy::AllocateAligned(seq_len * size_cache_pos); - kv_cache.value_cache = hwy::AllocateAligned(seq_len * size_cache_pos); - return kv_cache; -} - template <> void GemmaImpl::Generate( size_t max_tokens, size_t max_generated_tokens, float temperature, @@ -815,10 +825,6 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) { } Gemma::~Gemma() = default; // after GemmaInterface is defined -KVCache Gemma::CreateKVCache() const { - return impl_->CreateKVCache(); -} - const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const { return impl_->Tokenizer(); } diff --git a/gemma.h b/gemma.h index 03da2e5..58fd74a 100644 --- a/gemma.h +++ b/gemma.h @@ -157,13 +157,15 @@ struct Gemma { Gemma(const LoaderArgs& args, hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. - KVCache CreateKVCache() const; const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr impl_; gcpp::ModelTraining model_training; }; +KVCache CreateKVCache(Model type); // convenient workaround for now +KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len); + // StreamFunc is called with (token, probability). For prompt tokens, // probability is 0.0f. using StreamFunc = std::function; diff --git a/run.cc b/run.cc index ff7ed3d..40be63e 100644 --- a/run.cc +++ b/run.cc @@ -236,7 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { } gcpp::Gemma model(loader, pool); - auto kv_cache = model.CreateKVCache(); + auto kv_cache = CreateKVCache(loader.ModelType()); if (const char* error = inference.Validate()) { ShowHelp(loader, inference, app); From b67e28d1a03a97fb2b8b098fc4b12d9372baf36e Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Fri, 8 Mar 2024 00:00:11 -0500 Subject: [PATCH 11/20] [WIP] remove args from GetWeights, GetCompressedWeights --- gemma.cc | 26 ++++++++++++++------------ gemma.h | 46 +++++++++++++++++++++++++--------------------- run.cc | 1 + 3 files changed, 40 insertions(+), 33 deletions(-) diff --git a/gemma.cc b/gemma.cc index ba6aafa..f080dda 100644 --- a/gemma.cc +++ b/gemma.cc @@ -30,6 +30,7 @@ #include "hwy/highway.h" #include "hwy/profiler.h" #include "hwy/timer.h" +#include "util/app.h" // arg types #include "util/args.h" // Path // Non-SIMD includes and types. 
Note that HWY_ONCE is only true on the last @@ -697,10 +698,10 @@ void ForEachTensor(const Weights* weights, template hwy::AlignedFreeUniquePtr GetCompressedWeights( - const Path& model, const Path& cache, hwy::ThreadPool& pool) { + const Path& weights_path, const Path& cache, hwy::ThreadPool& pool) { PROFILER_ZONE("Startup.LoadCache"); - if (!std::filesystem::exists(model.path) && + if (!std::filesystem::exists(weights_path.path) && !std::filesystem::exists(cache.path)) { HWY_ABORT( "Either the model weights (--weights) or cached compressed weights " @@ -721,7 +722,7 @@ hwy::AlignedFreeUniquePtr GetCompressedWeights( // Get weights, compress, and store in cache. const hwy::AlignedUniquePtr> weights = - LoadWeights(model); + LoadWeights(weights_path); Compressor compressor(pool); ForEachTensor(weights.get(), *c_weights, compressor); compressor.WriteAll(pool, cache.path.c_str()); @@ -731,14 +732,17 @@ hwy::AlignedFreeUniquePtr GetCompressedWeights( // Type-erased because this function is called via a function pointer. hwy::AlignedFreeUniquePtr GetCompressedWeightsT( - const LoaderArgs& args, hwy::ThreadPool& pool) { - switch (args.ModelType()) { + gcpp::Model model, const Path& weights, const Path& compressed_weights, + hwy::ThreadPool& pool) { + switch (model) { case Model::GEMMA_2B: - return GetCompressedWeights(args.model, args.cache, pool); + return GetCompressedWeights(weights, compressed_weights, + pool); case Model::GEMMA_7B: - return GetCompressedWeights(args.model, args.cache, pool); + return GetCompressedWeights(weights, compressed_weights, + pool); default: - HWY_ABORT("Model type %d unknown.", static_cast(args.ModelType())); + HWY_ABORT("Model type %d unknown.", static_cast(model)); } } @@ -799,8 +803,6 @@ void GemmaImpl::Generate( kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); } -// TODO: Make Gemma type independent of LoaderArgs, create a factory function -// that takes LoaderArgs and creates a Gemma instance. Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) { const Model model_type = args.ModelType(); model_training = args.ModelTraining(); @@ -808,8 +810,8 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) { std::unique_ptr tokenizer = std::make_unique(); HWY_ASSERT(tokenizer->Load(args.tokenizer.path).ok()); - auto compressed_weights = - HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool); + auto compressed_weights = HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)( + args.ModelType(), args.model, args.cache, pool); switch (model_type) { case Model::GEMMA_2B: impl_.reset( diff --git a/gemma.h b/gemma.h index 58fd74a..1a6ca07 100644 --- a/gemma.h +++ b/gemma.h @@ -66,6 +66,9 @@ enum class ModelTraining { GEMMA_IT, GEMMA_PT }; // TODO: Incorporate this struct Runtime { + // TODO: In the future we may fold ModelTraining into Model. + // As we add more variations of model_type, the cartesian set becomes + // unwieldy. Model model_type; ModelTraining model_training; size_t max_tokens; @@ -126,7 +129,7 @@ struct LoaderArgs : public ArgsBase { Path tokenizer; Path model; // uncompressed weights OR - Path cache; // compressed weights + Path cache; // compressed weights (TODO: update name) std::string model_type; template @@ -151,26 +154,6 @@ struct LoaderArgs : public ArgsBase { } }; -struct GemmaInterface; - -struct Gemma { - Gemma(const LoaderArgs& args, hwy::ThreadPool& pool); - ~Gemma(); // must be defined after GemmaInterface's dtor is defined. 
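A reading note on the hunk above: the angle-bracketed template arguments were stripped during plain-text extraction, so `GetCompressedWeights(weights, compressed_weights, pool)` is really a template instantiated once per model configuration. Reconstructed under that assumption (the `uint8_t[]` payload type and the `ConfigGemma2B`/`ConfigGemma7B` config names are inferred from the surrounding diff, not copied verbatim), the type-erased dispatcher reads approximately:

```cpp
// Reconstruction sketch of the patch-11 dispatcher: a runtime Model enum is
// mapped onto compile-time config types, keeping the type-erased signature
// that HWY_DYNAMIC_DISPATCH requires for its function-pointer table.
hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeightsT(
    gcpp::Model model, const Path& weights, const Path& compressed_weights,
    hwy::ThreadPool& pool) {
  switch (model) {
    case Model::GEMMA_2B:
      return GetCompressedWeights<ConfigGemma2B>(weights, compressed_weights,
                                                 pool);
    case Model::GEMMA_7B:
      return GetCompressedWeights<ConfigGemma7B>(weights, compressed_weights,
                                                 pool);
    default:
      HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
  }
}
```

Passing `(model, weights, compressed_weights)` instead of a whole `LoaderArgs` is what lets patch 13 decouple the `Gemma` constructor from the argument parser. The `gemma.h` hunk resumes below, relocating the `GemmaInterface`/`Gemma` declarations to after `InferenceArgs`.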
- - const sentencepiece::SentencePieceProcessor* Tokenizer() const; - - std::unique_ptr impl_; - gcpp::ModelTraining model_training; -}; - -KVCache CreateKVCache(Model type); // convenient workaround for now -KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len); - -// StreamFunc is called with (token, probability). For prompt tokens, -// probability is 0.0f. -using StreamFunc = std::function; -using AcceptFunc = std::function; - struct InferenceArgs : public ArgsBase { InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } @@ -212,6 +195,27 @@ struct InferenceArgs : public ArgsBase { } }; +struct GemmaInterface; + +struct Gemma { + Gemma(const LoaderArgs& args, hwy::ThreadPool& pool); + ~Gemma(); // must be defined after GemmaInterface's dtor is defined. + const sentencepiece::SentencePieceProcessor* Tokenizer() const; + std::unique_ptr impl_; + gcpp::ModelTraining model_training; +}; + +struct LoaderArgs; // forward declaration +void CreateGemma(const LoaderArgs& args, hwy::ThreadPool& pool, Gemma& model); + +KVCache CreateKVCache(Model type); // convenient workaround for now +KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len); + +// StreamFunc is called with (token, probability). For prompt tokens, +// probability is 0.0f. +using StreamFunc = std::function; +using AcceptFunc = std::function; + void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, float temperature, const std::vector& prompt, size_t start_pos, KVCache& kv_cache, hwy::ThreadPool& pool, diff --git a/run.cc b/run.cc index 40be63e..64b6399 100644 --- a/run.cc +++ b/run.cc @@ -236,6 +236,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { } gcpp::Gemma model(loader, pool); + auto kv_cache = CreateKVCache(loader.ModelType()); if (const char* error = inference.Validate()) { From 42e53e2da89f80dc46399c7037fbbfb15cdc3de3 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Fri, 8 Mar 2024 14:55:35 -0500 Subject: [PATCH 12/20] [WIP] simplify hello world example, add convenience function. 
TODO: update git hash in CMakeLists.txt of hello world after push --- examples/hello_world/run.cc | 28 +++++++++++++++++----------- gemma.cc | 11 +++++++++++ gemma.h | 22 +++++++++++----------- 3 files changed, 39 insertions(+), 22 deletions(-) diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index a972154..fd6c762 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -21,20 +21,27 @@ std::vector tokenize( int main(int argc, char** argv) { gcpp::LoaderArgs loader(argc, argv); - // A rough heuristic for a reasonable number of threads given hardware - // concurrency estimate + + // A rough heuristic number of threads to use size_t num_threads = static_cast(std::clamp( static_cast(std::thread::hardware_concurrency()) - 2, 1, 18)); hwy::ThreadPool pool(num_threads); - hwy::ThreadPool inner_pool(0); + + // Instantiate model gcpp::Gemma model(loader, pool); + + // Setup random number generator std::mt19937 gen; std::random_device rd; gen.seed(rd()); + + // Tokenize instruction std::vector tokens = tokenize("Write a greeting to the world.", model.Tokenizer()); size_t ntokens = tokens.size(); size_t pos = 0; + + // Callback auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()]( int token, float) { ++pos; @@ -43,17 +50,16 @@ int main(int argc, char** argv) { } else if (token != gcpp::EOS_ID) { std::string token_text; HWY_ASSERT(tokenizer->Decode(std::vector{token}, &token_text).ok()); - if (pos == ntokens + 1) { - // first token of response - token_text.erase(0, token_text.find_first_not_of(" \t\n\n")); - } std::cout << token_text << std::flush; } return true; }; - GenerateGemma( - model, /*max_tokens=*/2048, /*max_generated_tokens=*/1024, - /*temperature=*/1.0, tokens, 0, pool, inner_pool, stream_token, - [](int) { return true; }, gen, 0); + + GenerateGemma(model, + {.max_tokens = 2048, + .max_generated_tokens = 1024, + .temperature = 1.0, + .verbosity = 0}, + tokens, /*KV cache position = */ 0, pool, stream_token, gen); std::cout << std::endl; } diff --git a/gemma.cc b/gemma.cc index f080dda..9c2df6c 100644 --- a/gemma.cc +++ b/gemma.cc @@ -844,5 +844,16 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, pool.SetWaitMode(hwy::PoolWaitMode::kBlock); } +void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config, + const std::vector& prompt, size_t start_pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + const StreamFunc& stream_token, std::mt19937& gen) { + hwy::ThreadPool inner_pool(0); + GenerateGemma( + gemma, runtime_config.max_tokens, runtime_config.max_generated_tokens, + runtime_config.temperature, prompt, start_pos, kv_cache, pool, inner_pool, + stream_token, [](int) { return true; }, gen, runtime_config.verbosity); +} + } // namespace gcpp #endif // HWY_ONCE diff --git a/gemma.h b/gemma.h index 1a6ca07..3bb4c28 100644 --- a/gemma.h +++ b/gemma.h @@ -64,17 +64,11 @@ struct KVCache { enum class Model { GEMMA_2B, GEMMA_7B }; enum class ModelTraining { GEMMA_IT, GEMMA_PT }; -// TODO: Incorporate this -struct Runtime { - // TODO: In the future we may fold ModelTraining into Model. - // As we add more variations of model_type, the cartesian set becomes - // unwieldy. 
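The convenience overload added to `gemma.cc` above is the heart of this patch: it constructs a dummy single-threaded `inner_pool`, accepts every token, and unpacks `RuntimeConfig` into the original thirteen-argument `GenerateGemma`. A hedged usage sketch, with field names taken from the overload's body above — note that the overload does take a `KVCache&`, even though the `hello_world` call in this patch does not yet pass one (patch 13 corrects that call site):

```cpp
// Sketch of calling the patch-12 convenience overload. `model`, `tokens`,
// `kv_cache`, `pool`, `stream_token`, and `gen` are assumed to be set up as
// in the hello_world example.
gcpp::RuntimeConfig runtime_config{.max_tokens = 2048,
                                   .max_generated_tokens = 1024,
                                   .temperature = 1.0f,
                                   .verbosity = 0};
gcpp::GenerateGemma(model, runtime_config, tokens, /*start_pos=*/0, kv_cache,
                    pool, stream_token, gen);
```

The `gemma.h` hunk continues below, replacing the unused `Runtime` sketch with the `RuntimeConfig` struct itself.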
- Model model_type; - ModelTraining model_training; +struct RuntimeConfig { size_t max_tokens; size_t max_generated_tokens; float temperature; - std::mt19937 gen; + int verbosity; }; struct LoaderArgs : public ArgsBase { @@ -205,9 +199,6 @@ struct Gemma { gcpp::ModelTraining model_training; }; -struct LoaderArgs; // forward declaration -void CreateGemma(const LoaderArgs& args, hwy::ThreadPool& pool, Gemma& model); - KVCache CreateKVCache(Model type); // convenient workaround for now KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len); @@ -223,6 +214,15 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens, const AcceptFunc& accept_token, std::mt19937& gen, int verbosity); +// Convenience function for the common case: +// - Bundle runtime parameters as RuntimeConfig +// - No threadpools within threadpools (inner_pool = dummy) +// - All tokens accepted +void GenerateGemma(Gemma& gemma, RuntimeConfig runtime_config, + const std::vector& prompt, size_t start_pos, + KVCache& kv_cache, hwy::ThreadPool& pool, + const StreamFunc& stream_token, std::mt19937& gen); + constexpr int EOS_ID = 1; } // namespace gcpp From dfd2fdc1dd8e7a84d2e2f9618334b87a79ba02b1 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Fri, 8 Mar 2024 17:26:03 -0500 Subject: [PATCH 13/20] Decouple gemma constructor from loader args, update hello_world example, add convenience version of constructor (no uncompressed weights) --- examples/hello_world/CMakeLists.txt | 13 ++++++++++++- examples/hello_world/run.cc | 14 ++++++++------ gemma.cc | 17 +++++++++++------ gemma.h | 5 ++++- run.cc | 2 +- 5 files changed, 36 insertions(+), 15 deletions(-) diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt index 088af84..43159d7 100644 --- a/examples/hello_world/CMakeLists.txt +++ b/examples/hello_world/CMakeLists.txt @@ -22,7 +22,18 @@ FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.gi FetchContent_MakeAvailable(highway) FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c) FetchContent_MakeAvailable(sentencepiece) -FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG e781007836ec034236e90cc4d313d0a8c481bce6) + + +# Allow for both local and remote building) +option(BUILD_MODE "'local' or 'remote' git fetch for builds") +if (NOT BUILD_MODE) + set(BUILD_MODE "remote") +endif() +if (BUILD_MODE STREQUAL "local") + FetchContent_Declare(gemma SOURCE_DIR ../../..) 
+else() + FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 42e53e2da89f80dc46399c7037fbbfb15cdc3de3) +endif() FetchContent_MakeAvailable(gemma) if(NOT CMAKE_BUILD_TYPE) diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index fd6c762..8b5de24 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -22,15 +22,17 @@ std::vector tokenize( int main(int argc, char** argv) { gcpp::LoaderArgs loader(argc, argv); - // A rough heuristic number of threads to use + // Rough heuristic for the number of threads to use size_t num_threads = static_cast(std::clamp( static_cast(std::thread::hardware_concurrency()) - 2, 1, 18)); hwy::ThreadPool pool(num_threads); - // Instantiate model - gcpp::Gemma model(loader, pool); + // Instantiate model and KV Cache + gcpp::Gemma model(loader.tokenizer, loader.cache, loader.ModelType(), pool); + auto kv_cache = CreateKVCache(loader.ModelType()); + size_t pos = 0; // KV Cache position - // Setup random number generator + // Initialize random number generator std::mt19937 gen; std::random_device rd; gen.seed(rd()); @@ -39,7 +41,6 @@ int main(int argc, char** argv) { std::vector tokens = tokenize("Write a greeting to the world.", model.Tokenizer()); size_t ntokens = tokens.size(); - size_t pos = 0; // Callback auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()]( @@ -60,6 +61,7 @@ int main(int argc, char** argv) { .max_generated_tokens = 1024, .temperature = 1.0, .verbosity = 0}, - tokens, /*KV cache position = */ 0, pool, stream_token, gen); + tokens, /*KV cache position = */ 0, kv_cache, pool, + stream_token, gen); std::cout << std::endl; } diff --git a/gemma.cc b/gemma.cc index 9c2df6c..15b3c26 100644 --- a/gemma.cc +++ b/gemma.cc @@ -285,7 +285,6 @@ struct GemmaImpl : public GemmaInterface { int verbosity) override; std::unique_ptr tokenizer; - hwy::AlignedFreeUniquePtr compressed_weights; hwy::AlignedUniquePtr> prefill; hwy::AlignedUniquePtr> state; @@ -803,15 +802,15 @@ void GemmaImpl::Generate( kv_cache, pool, inner_pool, stream_token, accept_token, gen, verbosity); } -Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) { - const Model model_type = args.ModelType(); - model_training = args.ModelTraining(); +Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + const Path& weights_path, Model model_type, + hwy::ThreadPool& pool) { PROFILER_ZONE("Startup.tokenizer"); std::unique_ptr tokenizer = std::make_unique(); - HWY_ASSERT(tokenizer->Load(args.tokenizer.path).ok()); + HWY_ASSERT(tokenizer->Load(tokenizer_path.path).ok()); auto compressed_weights = HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)( - args.ModelType(), args.model, args.cache, pool); + model_type, weights_path, compressed_weights_path, pool); switch (model_type) { case Model::GEMMA_2B: impl_.reset( @@ -825,6 +824,12 @@ Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) { HWY_ABORT("Model type %d unknown.", static_cast(model_type)); } } + +Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + Model model_type, hwy::ThreadPool& pool) + : Gemma(tokenizer_path, compressed_weights_path, Path{""}, model_type, + pool) {} + Gemma::~Gemma() = default; // after GemmaInterface is defined const sentencepiece::SentencePieceProcessor* Gemma::Tokenizer() const { diff --git a/gemma.h b/gemma.h index 3bb4c28..7c08412 100644 --- a/gemma.h +++ b/gemma.h @@ -192,7 +192,10 @@ struct InferenceArgs : public ArgsBase { struct GemmaInterface; struct Gemma { - 
Gemma(const LoaderArgs& args, hwy::ThreadPool& pool); + Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + const Path& weights_path, Model model_type, hwy::ThreadPool& pool); + Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, + Model model_type, hwy::ThreadPool& pool); ~Gemma(); // must be defined after GemmaInterface's dtor is defined. const sentencepiece::SentencePieceProcessor* Tokenizer() const; std::unique_ptr impl_; diff --git a/run.cc b/run.cc index 64b6399..cdeb95e 100644 --- a/run.cc +++ b/run.cc @@ -235,7 +235,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); } - gcpp::Gemma model(loader, pool); + gcpp::Gemma model(loader.tokenizer, loader.cache, loader.ModelType(), pool); auto kv_cache = CreateKVCache(loader.ModelType()); From 03147effbd241530611c99f46df51ce5b87f2d8a Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Fri, 8 Mar 2024 17:32:36 -0500 Subject: [PATCH 14/20] update loader arg names: cache -> compressed_weights, model -> weights --- examples/hello_world/CMakeLists.txt | 2 +- examples/hello_world/run.cc | 3 ++- gemma.h | 10 +++++----- run.cc | 3 ++- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt index 43159d7..97686dd 100644 --- a/examples/hello_world/CMakeLists.txt +++ b/examples/hello_world/CMakeLists.txt @@ -32,7 +32,7 @@ endif() if (BUILD_MODE STREQUAL "local") FetchContent_Declare(gemma SOURCE_DIR ../../..) else() - FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 42e53e2da89f80dc46399c7037fbbfb15cdc3de3) + FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG dfd2fdc1dd8e7a84d2e2f9618334b87a79ba02b1) endif() FetchContent_MakeAvailable(gemma) diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index 8b5de24..7b57403 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -28,7 +28,8 @@ int main(int argc, char** argv) { hwy::ThreadPool pool(num_threads); // Instantiate model and KV Cache - gcpp::Gemma model(loader.tokenizer, loader.cache, loader.ModelType(), pool); + gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, + loader.ModelType(), pool); auto kv_cache = CreateKVCache(loader.ModelType()); size_t pos = 0; // KV Cache position diff --git a/gemma.h b/gemma.h index 7c08412..48fb52b 100644 --- a/gemma.h +++ b/gemma.h @@ -114,7 +114,7 @@ struct LoaderArgs : public ArgsBase { if (tokenizer.path.empty()) { return "Missing --tokenizer flag, a file for the tokenizer is required."; } - if (cache.path.empty()) { + if (compressed_weights.path.empty()) { return "Missing --compressed_weights flag, a file for the compressed " "model."; } @@ -122,8 +122,8 @@ struct LoaderArgs : public ArgsBase { } Path tokenizer; - Path model; // uncompressed weights OR - Path cache; // compressed weights (TODO: update name) + Path weights; // uncompressed weights file location + Path compressed_weights; // compressed weights file location std::string model_type; template @@ -131,7 +131,7 @@ struct LoaderArgs : public ArgsBase { visitor(tokenizer, "tokenizer", Path(), "Path name of tokenizer model file.\n Required argument."); visitor( - cache, "compressed_weights", Path(), + compressed_weights, "compressed_weights", Path(), "Path name of compressed weights file, regenerated from `--weights` " "file if " "the compressed weights file does not 
exist.\n Required argument."); @@ -140,7 +140,7 @@ struct LoaderArgs : public ArgsBase { "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters " "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n" " Required argument."); - visitor(model, "weights", Path(), + visitor(weights, "weights", Path(), "Path name of model weights (.sbs) file. Only required if " "compressed_weights file is not present and needs to be " "regenerated. This parameter is only required for compressing" diff --git a/run.cc b/run.cc index cdeb95e..c9fa78d 100644 --- a/run.cc +++ b/run.cc @@ -235,7 +235,8 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { [](uint64_t /*task*/, size_t thread) { PinThreadToCore(thread); }); } - gcpp::Gemma model(loader.tokenizer, loader.cache, loader.ModelType(), pool); + gcpp::Gemma model(loader.tokenizer, loader.compressed_weights, + loader.ModelType(), pool); auto kv_cache = CreateKVCache(loader.ModelType()); From 571a5449c4e1f45377582d26c6400215d6e5797a Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Fri, 8 Mar 2024 17:33:33 -0500 Subject: [PATCH 15/20] update commit hash for gemma lib --- examples/hello_world/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt index 97686dd..eb574aa 100644 --- a/examples/hello_world/CMakeLists.txt +++ b/examples/hello_world/CMakeLists.txt @@ -32,7 +32,7 @@ endif() if (BUILD_MODE STREQUAL "local") FetchContent_Declare(gemma SOURCE_DIR ../../..) else() - FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG dfd2fdc1dd8e7a84d2e2f9618334b87a79ba02b1) + FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 03147effbd241530611c99f46df51ce5b87f2d8a) endif() FetchContent_MakeAvailable(gemma) From 8c7b2cf61b9794b806de091685dc6739dd3db837 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Fri, 8 Mar 2024 17:59:54 -0500 Subject: [PATCH 16/20] add README, license to hello_world --- examples/hello_world/CMakeLists.txt | 1 + examples/hello_world/README.md | 41 +++++++++++++++++++++++++++++ examples/hello_world/run.cc | 17 +++++++++++- 3 files changed, 58 insertions(+), 1 deletion(-) create mode 100644 examples/hello_world/README.md diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt index eb574aa..b49d02e 100644 --- a/examples/hello_world/CMakeLists.txt +++ b/examples/hello_world/CMakeLists.txt @@ -24,6 +24,7 @@ FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sent FetchContent_MakeAvailable(sentencepiece) + # Allow for both local and remote building) option(BUILD_MODE "'local' or 'remote' git fetch for builds") if (NOT BUILD_MODE) diff --git a/examples/hello_world/README.md b/examples/hello_world/README.md new file mode 100644 index 0000000..fc117fe --- /dev/null +++ b/examples/hello_world/README.md @@ -0,0 +1,41 @@ +# Hello World Example + +This is a minimal/template project for using `gemma.cpp` as a library. Instead of an interactive interface, it sets up the model state and generates text for a single hard coded prompt. + +Build steps are similar to the main `gemma` executable. From inside the top-level directory. For now only `cmake`/`make` is available for builds (PRs welcome for other build options). 
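Before the README walks through the build steps, it is worth seeing how the patch-14 rename surfaces in embedding code: `loader.cache` and `loader.model` become `loader.compressed_weights` and `loader.weights`, matching the `--compressed_weights` / `--weights` flags. A short sketch under the renamed fields (pool construction elided):

```cpp
// Sketch: call sites after the loader-argument rename, combining the
// patch-13 constructor with the free CreateKVCache from patch 10.
gcpp::LoaderArgs loader(argc, argv);
gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
                  loader.ModelType(), pool);
gcpp::KVCache kv_cache = gcpp::CreateKVCache(loader.ModelType());
```

The new README continues below with the concrete `cmake` configuration steps.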
+ +First use `cmake` to configure the project, assuming you are in the `hello_world` example directory (`gemma.cpp/examples/hello_world`): + +```sh +cmake -B build +``` + +This sets up a build configuration in `gemma.cpp/examples/hello_world/build`. Note that this fetches `libgemma` from a git commit hash on github. Alternatively if you want to build using the local version of `gemma.cpp` use: + + +```sh +cmake -B build -DBUILD_MODE=local +``` + +Make sure you delete the contents of the build directory before changing configurations. + +Then use `make` to build the project: + +```sh +cd build +make hello_world +``` + +As with the top-level `gemma.cpp` project you can use the `make` commands `-j` flag to use parallel threads for faster builds. + +From inside the `gemma.cpp/examples/hello_world/build` directory, there should be a `hello_world` executable. You can run it with the same 3 model arguments as gemma.cpp specifying the tokenizer, compressed weights file, and model type, for example: + +```sh +./hello_world --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it +``` + +Should print a greeting to the terminal: + +``` +"Hello, world! It's a pleasure to greet you all. May your day be filled with joy, peace, and all the things that make your heart soar. +``` diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index 7b57403..fc74130 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -1,3 +1,18 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include // copybara:import_next_line:gemma_cpp @@ -43,7 +58,7 @@ int main(int argc, char** argv) { tokenize("Write a greeting to the world.", model.Tokenizer()); size_t ntokens = tokens.size(); - // Callback + // This callback function gets invoked everytime a token is generated auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()]( int token, float) { ++pos; From cc5c24c4f8ce2968e1c6ac1e2822cf15fb81d0b5 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Fri, 8 Mar 2024 18:06:43 -0500 Subject: [PATCH 17/20] remove app.h dependency + fix bazel build --- examples/hello_world/CMakeLists.txt | 2 +- gemma.cc | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt index b49d02e..56d7f77 100644 --- a/examples/hello_world/CMakeLists.txt +++ b/examples/hello_world/CMakeLists.txt @@ -33,7 +33,7 @@ endif() if (BUILD_MODE STREQUAL "local") FetchContent_Declare(gemma SOURCE_DIR ../../..) 
else() - FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 03147effbd241530611c99f46df51ce5b87f2d8a) + FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 8c7b2cf61b9794b806de091685dc6739dd3db837) endif() FetchContent_MakeAvailable(gemma) diff --git a/gemma.cc b/gemma.cc index 15b3c26..9eed24a 100644 --- a/gemma.cc +++ b/gemma.cc @@ -30,7 +30,6 @@ #include "hwy/highway.h" #include "hwy/profiler.h" #include "hwy/timer.h" -#include "util/app.h" // arg types #include "util/args.h" // Path // Non-SIMD includes and types. Note that HWY_ONCE is only true on the last From 0fc80fad05ccd18947c5667119351d7221fea8c5 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Sun, 10 Mar 2024 12:55:08 -0400 Subject: [PATCH 18/20] libgemma refactor - review changes --- examples/hello_world/CMakeLists.txt | 7 +++---- examples/hello_world/README.md | 28 +++++++++++++++++++--------- examples/hello_world/run.cc | 8 ++++---- gemma.cc | 27 ++++++++++++--------------- run.cc | 1 - 5 files changed, 38 insertions(+), 33 deletions(-) diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt index 56d7f77..397b957 100644 --- a/examples/hello_world/CMakeLists.txt +++ b/examples/hello_world/CMakeLists.txt @@ -14,7 +14,6 @@ cmake_minimum_required(VERSION 3.11) project(hello_world) -set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) include(FetchContent) @@ -31,18 +30,18 @@ if (NOT BUILD_MODE) set(BUILD_MODE "remote") endif() if (BUILD_MODE STREQUAL "local") - FetchContent_Declare(gemma SOURCE_DIR ../../..) + # Relative path to gemma.cpp from examples/hello_world/build/ + FetchContent_Declare(gemma SOURCE_DIR ../../..) else() FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 8c7b2cf61b9794b806de091685dc6739dd3db837) endif() -FetchContent_MakeAvailable(gemma) +FetchContent_MakeAvailabl(gemma) if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE "Release") endif() add_executable(hello_world run.cc) -set_property(TARGET hello_world PROPERTY CXX_STANDARD 17) target_link_libraries(hello_world hwy hwy_contrib sentencepiece libgemma) FetchContent_GetProperties(sentencepiece) target_include_directories(hello_world PRIVATE ${sentencepiece_SOURCE_DIR}) diff --git a/examples/hello_world/README.md b/examples/hello_world/README.md index fc117fe..63c319e 100644 --- a/examples/hello_world/README.md +++ b/examples/hello_world/README.md @@ -1,23 +1,29 @@ # Hello World Example -This is a minimal/template project for using `gemma.cpp` as a library. Instead of an interactive interface, it sets up the model state and generates text for a single hard coded prompt. +This is a minimal/template project for using `gemma.cpp` as a library. Instead +of an interactive interface, it sets up the model state and generates text for a +single hard coded prompt. -Build steps are similar to the main `gemma` executable. From inside the top-level directory. For now only `cmake`/`make` is available for builds (PRs welcome for other build options). +Build steps are similar to the main `gemma` executable. For now only +`cmake`/`make` is available for builds (PRs welcome for other build options). 
-First use `cmake` to configure the project, assuming you are in the `hello_world` example directory (`gemma.cpp/examples/hello_world`): +First use `cmake` to configure the project, starting from the `hello_world` +example directory (`gemma.cpp/examples/hello_world`): ```sh cmake -B build ``` -This sets up a build configuration in `gemma.cpp/examples/hello_world/build`. Note that this fetches `libgemma` from a git commit hash on github. Alternatively if you want to build using the local version of `gemma.cpp` use: - +This sets up a build configuration in `gemma.cpp/examples/hello_world/build`. +Note that this fetches `libgemma` from a git commit hash on github. +Alternatively if you want to build using the local version of `gemma.cpp` use: ```sh cmake -B build -DBUILD_MODE=local ``` -Make sure you delete the contents of the build directory before changing configurations. +Make sure you delete the contents of the build directory before changing +configurations. Then use `make` to build the project: @@ -26,15 +32,19 @@ cd build make hello_world ``` -As with the top-level `gemma.cpp` project you can use the `make` commands `-j` flag to use parallel threads for faster builds. +As with the top-level `gemma.cpp` project you can use the `make` commands `-j` +flag to use parallel threads for faster builds. -From inside the `gemma.cpp/examples/hello_world/build` directory, there should be a `hello_world` executable. You can run it with the same 3 model arguments as gemma.cpp specifying the tokenizer, compressed weights file, and model type, for example: +From inside the `gemma.cpp/examples/hello_world/build` directory, there should +be a `hello_world` executable. You can run it with the same 3 model arguments as +gemma.cpp specifying the tokenizer, compressed weights file, and model type, for +example: ```sh ./hello_world --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it ``` -Should print a greeting to the terminal: +Should print a greeting to the terminal: ``` "Hello, world! It's a pleasure to greet you all. May your day be filled with joy, peace, and all the things that make your heart soar. 
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index fc74130..8ec784f 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -24,12 +24,12 @@ #include "hwy/contrib/thread_pool/thread_pool.h" std::vector tokenize( - std::string prompt_string, + const std::string& prompt_string, const sentencepiece::SentencePieceProcessor* tokenizer) { - prompt_string = "user\n" + prompt_string + - "\nmodel\n"; + std::string formatted = "user\n" + prompt_string + + "\nmodel\n"; std::vector tokens; - HWY_ASSERT(tokenizer->Encode(prompt_string, &tokens).ok()); + HWY_ASSERT(tokenizer->Encode(formatted, &tokens).ok()); tokens.insert(tokens.begin(), 2); // BOS token return tokens; } diff --git a/gemma.cc b/gemma.cc index 9eed24a..6f2e664 100644 --- a/gemma.cc +++ b/gemma.cc @@ -261,10 +261,9 @@ KVCache CreateKVCache(Model type) { template struct GemmaImpl : public GemmaInterface { - GemmaImpl( // const LoaderArgs& args, - std::unique_ptr& tokenizer, - hwy::AlignedFreeUniquePtr& compressed_weights, - hwy::ThreadPool& pool); + GemmaImpl(std::unique_ptr& tokenizer, + hwy::AlignedFreeUniquePtr& compressed_weights, + hwy::ThreadPool& pool); ~GemmaImpl() { using CWeights = CompressedWeights; @@ -767,16 +766,10 @@ GemmaImpl::GemmaImpl( std::unique_ptr& tokenizer, hwy::AlignedFreeUniquePtr& compressed_weights, hwy::ThreadPool& pool) - // GemmaImpl::GemmaImpl(const LoaderArgs& args, hwy::ThreadPool& - // pool) : compressed_weights(std::move(compressed_weights)), - // HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)(args, pool)), prefill(hwy::MakeUniqueAligned>()), state(hwy::MakeUniqueAligned>()), - tokenizer(std::move(tokenizer)) { - // PROFILER_ZONE("Startup.tokenizer"); - // HWY_ASSERT(tokenizer.Load(args.tokenizer.path).ok()); -} + tokenizer(std::move(tokenizer)) {} template <> void GemmaImpl::Generate( @@ -804,10 +797,14 @@ void GemmaImpl::Generate( Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, const Path& weights_path, Model model_type, hwy::ThreadPool& pool) { - PROFILER_ZONE("Startup.tokenizer"); - std::unique_ptr tokenizer = - std::make_unique(); - HWY_ASSERT(tokenizer->Load(tokenizer_path.path).ok()); + { + PROFILER_ZONE("Startup.tokenizer"); + std::unique_ptr tokenizer = + std::make_unique(); + if (!tokenizer->Load(tokenizer_path.path).ok()) { + HWY_ABORT("Failed to load the tokenizer file."); + } + } auto compressed_weights = HWY_DYNAMIC_DISPATCH(GetCompressedWeightsT)( model_type, weights_path, compressed_weights_path, pool); switch (model_type) { diff --git a/run.cc b/run.cc index c9fa78d..1487a8a 100644 --- a/run.cc +++ b/run.cc @@ -190,7 +190,6 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache, } } - // HWY_ASSERT(model.Tokenizer().Encode(prompt_string, &prompt).ok()); HWY_ASSERT(model.Tokenizer()->Encode(prompt_string, &prompt).ok()); // For both pre-trained and instruction-tuned models: prepend "" token From 5d323c00fe6871fe5ba68afd2a995fc5667ee1f7 Mon Sep 17 00:00:00 2001 From: austinvhuang Date: Sun, 10 Mar 2024 13:23:16 -0400 Subject: [PATCH 19/20] fix tokenizer scope --- gemma.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gemma.cc b/gemma.cc index 6f2e664..c9066eb 100644 --- a/gemma.cc +++ b/gemma.cc @@ -797,10 +797,10 @@ void GemmaImpl::Generate( Gemma::Gemma(const Path& tokenizer_path, const Path& compressed_weights_path, const Path& weights_path, Model model_type, hwy::ThreadPool& pool) { + std::unique_ptr tokenizer; { PROFILER_ZONE("Startup.tokenizer"); - std::unique_ptr tokenizer = - 
std::make_unique<sentencepiece::SentencePieceProcessor>();
+  tokenizer = std::make_unique<sentencepiece::SentencePieceProcessor>();
   if (!tokenizer->Load(tokenizer_path.path).ok()) {
     HWY_ABORT("Failed to load the tokenizer file.");
   }

From 415464b047829d4fd64a6b7fde2b1c0ba843b793 Mon Sep 17 00:00:00 2001
From: austinvhuang
Date: Sun, 10 Mar 2024 15:41:17 -0400
Subject: [PATCH 20/20] fix CMakeLists typo

---
 examples/hello_world/CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt
index 397b957..9d44f04 100644
--- a/examples/hello_world/CMakeLists.txt
+++ b/examples/hello_world/CMakeLists.txt
@@ -35,7 +35,7 @@ if (BUILD_MODE STREQUAL "local")
 else()
   FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 8c7b2cf61b9794b806de091685dc6739dd3db837)
 endif()
-FetchContent_MakeAvailabl(gemma)
+FetchContent_MakeAvailable(gemma)

 if(NOT CMAKE_BUILD_TYPE)
   set(CMAKE_BUILD_TYPE "Release")
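Taken end to end, the series leaves the embedding API in roughly the following shape. This condensed sketch combines the compressed-weights constructor (patch 13), the renamed loader fields (patch 14), and the `RuntimeConfig` overload (patch 12); the thread count and the placeholder token list are illustrative assumptions, and error handling is omitted:

```cpp
#include <random>
#include <vector>

#include "gemma.h"  // gcpp::Gemma, LoaderArgs, RuntimeConfig, CreateKVCache
#include "hwy/contrib/thread_pool/thread_pool.h"

int main(int argc, char** argv) {
  gcpp::LoaderArgs loader(argc, argv);
  hwy::ThreadPool pool(/*num_threads=*/4);

  // Model construction and cache creation are now decoupled from LoaderArgs
  // internals: three explicit inputs, one free function for the cache.
  gcpp::Gemma model(loader.tokenizer, loader.compressed_weights,
                    loader.ModelType(), pool);
  gcpp::KVCache kv_cache = gcpp::CreateKVCache(loader.ModelType());

  std::mt19937 gen{std::random_device{}()};
  std::vector<int> tokens = {2};  // placeholder: BOS only; tokenize a real
                                  // prompt as in examples/hello_world/run.cc

  gcpp::GenerateGemma(model,
                      {.max_tokens = 2048,
                       .max_generated_tokens = 1024,
                       .temperature = 1.0f,
                       .verbosity = 0},
                      tokens, /*start_pos=*/0, kv_cache, pool,
                      /*stream_token=*/[](int, float) { return true; }, gen);
  return 0;
}
```

At this point the caller never touches `GemmaInterface`, uncompressed weights, or an inner thread pool directly: model construction, cache creation, and generation are three calls with explicit inputs.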