From e781007836ec034236e90cc4d313d0a8c481bce6 Mon Sep 17 00:00:00 2001
From: austinvhuang
Date: Wed, 6 Mar 2024 23:21:13 -0500
Subject: [PATCH] [WIP] Remove InferenceArgs from hello_world example, fix
 ordering of LoaderArgs validation, revert ReplGemma EOT token behavior

---
 examples/hello_world/CMakeLists.txt |  2 +-
 examples/hello_world/run.cc         | 37 +++++++++++++----------------
 gemma.cc                            |  8 ++++---
 gemma.h                             |  8 +++----
 run.cc                              |  2 +-
 5 files changed, 28 insertions(+), 29 deletions(-)

diff --git a/examples/hello_world/CMakeLists.txt b/examples/hello_world/CMakeLists.txt
index 69e4e98..63c8d7b 100644
--- a/examples/hello_world/CMakeLists.txt
+++ b/examples/hello_world/CMakeLists.txt
@@ -27,7 +27,7 @@ FetchContent_MakeAvailable(highway)
 FetchContent_Declare(sentencepiece GIT_REPOSITORY https://github.com/google/sentencepiece GIT_TAG 53de76561cfc149d3c01037f0595669ad32a5e7c)
 FetchContent_MakeAvailable(sentencepiece)
 
-FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 808dbdc42b216c3ac1f1c40dfa638bcff24bbd2b)
+FetchContent_Declare(gemma GIT_REPOSITORY https://github.com/google/gemma.cpp.git GIT_TAG 7042316013d7e2ad06532f420b551802ec114f89)
 FetchContent_MakeAvailable(gemma)
 
 if(NOT CMAKE_BUILD_TYPE)
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index 67953dc..227c91c 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -7,9 +7,6 @@
 #include "gemma.h"  // Gemma
 // copybara:end
 // copybara:import_next_line:gemma_cpp
-#include "util/app.h"
-// copybara:end
-// copybara:import_next_line:gemma_cpp
 #include "util/args.h"  // HasHelp
 // copybara:end
 // copybara:import_next_line:gemma_cpp
@@ -22,30 +19,35 @@
 #include "hwy/profiler.h"
 #include "hwy/timer.h"
-std::vector<int> tokenize(std::string prompt_string, const sentencepiece::SentencePieceProcessor& tokenizer) {
+std::vector<int> tokenize(
+    std::string prompt_string,
+    const sentencepiece::SentencePieceProcessor* tokenizer) {
   prompt_string = "<start_of_turn>user\n" + prompt_string + "<end_of_turn>\n<start_of_turn>model\n";
   std::vector<int> tokens;
-  HWY_ASSERT(tokenizer.Encode(prompt_string, &tokens).ok());
-  tokens.insert(tokens.begin(), 2); // BOS token
+  HWY_ASSERT(tokenizer->Encode(prompt_string, &tokens).ok());
+  tokens.insert(tokens.begin(), 2);  // BOS token
   return tokens;
 }
 
 int main(int argc, char** argv) {
-  gcpp::InferenceArgs inference(argc, argv);
   gcpp::LoaderArgs loader(argc, argv);
-  gcpp::AppArgs app(argc, argv);
-  hwy::ThreadPool pool(app.num_threads);
+  // A rough heuristic for a reasonable number of threads given hardware
+  // concurrency estimate
+  size_t num_threads = static_cast<size_t>(std::clamp(
+      static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
+  hwy::ThreadPool pool(num_threads);
   hwy::ThreadPool inner_pool(0);
   gcpp::Gemma model(loader, pool);
 
   std::mt19937 gen;
   std::random_device rd;
   gen.seed(rd());
-
-  std::vector<int> tokens = tokenize("Hello, how are you?", model.Tokenizer());
+  std::vector<int> tokens =
+      tokenize("Write a greeting to the world.", model.Tokenizer());
   size_t ntokens = tokens.size();
   size_t pos = 0;
 
-  auto stream_token = [&pos, &gen, &ntokens, tokenizer = &model.Tokenizer()](int token, float) {
+  auto stream_token = [&pos, &gen, &ntokens, tokenizer = model.Tokenizer()](
+                          int token, float) {
     ++pos;
     if (pos < ntokens) {
       // print feedback
@@ -60,14 +62,9 @@ int main(int argc, char** argv) {
     }
     return true;
   };
-
-  inference.temperature = 1.0f;
-  inference.deterministic = true;
-  inference.multiturn = false;
   GenerateGemma(
-      model, /*max_tokens=*/2048, /*max_generated_tokens=*/1024, 
-      /*temperature=*/1.0, tokens, 0, pool, inner_pool, stream_token, 
-      [](int) {return true;}, gen, 0);
-
+      model, /*max_tokens=*/2048, /*max_generated_tokens=*/1024,
+      /*temperature=*/1.0, tokens, 0, pool, inner_pool, stream_token,
+      [](int) { return true; }, gen, 0);
   std::cout << std::endl;
 }
diff --git a/gemma.cc b/gemma.cc
index d49d0df..bbd86c3 100644
--- a/gemma.cc
+++ b/gemma.cc
@@ -782,6 +782,8 @@ void GemmaImpl::Generate(
                pool, inner_pool, stream_token, accept_token, gen, verbosity);
 }
 
+// TODO: Make Gemma type independent of LoaderArgs, create a factory function
+// that takes LoaderArgs and creates a Gemma instance.
 Gemma::Gemma(const LoaderArgs& args, hwy::ThreadPool& pool) {
   const Model model_type = args.ModelType();
   model_training = args.ModelTraining();
@@ -817,9 +819,9 @@ void GenerateGemma(Gemma& gemma, size_t max_tokens, size_t max_generated_tokens,
                    const AcceptFunc& accept_token, std::mt19937& gen,
                    int verbosity) {
   pool.SetWaitMode(hwy::PoolWaitMode::kSpin);
-  gemma.impl_->Generate(max_tokens, max_generated_tokens,
-                        temperature, prompt, start_pos, pool, inner_pool,
-                        stream_token, accept_token, gen, verbosity);
+  gemma.impl_->Generate(max_tokens, max_generated_tokens, temperature, prompt,
+                        start_pos, pool, inner_pool, stream_token, accept_token,
+                        gen, verbosity);
   pool.SetWaitMode(hwy::PoolWaitMode::kBlock);
 }
 
diff --git a/gemma.h b/gemma.h
index 5a2f2b0..3528b50 100644
--- a/gemma.h
+++ b/gemma.h
@@ -104,6 +104,10 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
   // Returns error string or nullptr if OK.
   const char* Validate() const {
     const std::string model_type_lc = ToLower(model_type);
+    if (model_type.empty()) {
+      return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
+             "2b-it, or 7b-it.";
+    }
     if (model_type_lc != "2b-pt" && model_type_lc != "7b-pt" &&
         model_type_lc != "2b-it" && model_type_lc != "7b-it") {
       return "Model type must be 2b-pt, 7b-pt, 2b-it, or "
@@ -112,10 +116,6 @@
     if (tokenizer.path.empty()) {
       return "Missing --tokenizer flag, a file for the tokenizer is required.";
     }
-    if (model_type.empty()) {
-      return "Missing --model flag, need to specify either 2b-pt, 7b-pt, "
-             "2b-it, or 7b-it.";
-    }
     if (cache.path.empty()) {
       return "Missing --compressed_weights flag, a file for the compressed "
              "model.";
diff --git a/run.cc b/run.cc
index eac5f9e..4c4f132 100644
--- a/run.cc
+++ b/run.cc
@@ -186,7 +186,7 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
       if (abs_pos > 0) {
         // Prepend "<end_of_turn>" token if this is a multi-turn dialogue
         // continuation.
-        prompt_string = "<start_of_turn>model\n" + prompt_string;
+        prompt_string = "<end_of_turn>\n" + prompt_string;
       }
     }
 
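
Note on the TODO added in gemma.cc: one possible shape for the factory function it describes is sketched below. The name CreateGemma, the std::unique_ptr return type (which needs <memory>), and the HWY_ABORT error handling are illustrative assumptions, not part of this patch; only LoaderArgs::Validate() and the Gemma(LoaderArgs, ThreadPool) constructor come from the change itself:

    // Hypothetical factory sketch: validate the flags up front, then
    // construct the model. A later step could unpack only the fields Gemma
    // actually needs, making the Gemma type itself independent of LoaderArgs.
    std::unique_ptr<gcpp::Gemma> CreateGemma(const gcpp::LoaderArgs& args,
                                             hwy::ThreadPool& pool) {
      if (const char* error = args.Validate()) {
        HWY_ABORT("Invalid loader args: %s", error);
      }
      return std::make_unique<gcpp::Gemma>(args, pool);
    }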