From 06f814fc8badd2e35d67e69581f1302e5bec94ee Mon Sep 17 00:00:00 2001 From: Daniel Keysers Date: Fri, 7 Jun 2024 05:32:50 -0700 Subject: [PATCH] Small code cleanup suggestions while reading the code. PiperOrigin-RevId: 641220788 --- gemma/gemma.h | 4 +++- gemma/run.cc | 20 +++++++++++--------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/gemma/gemma.h b/gemma/gemma.h index e9fcb2f..e9ffc08 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -44,7 +44,9 @@ struct KVCache { static KVCache Create(Model type); }; +// The tokenizer's end of sentence and beginning of sentence token ids. constexpr int EOS_ID = 1; +constexpr int BOS_ID = 2; class GemmaTokenizer { public: @@ -87,7 +89,7 @@ struct RuntimeConfig { struct TimingInfo { double prefill_tok_sec = 0.0; double gen_tok_sec = 0.0; - double time_to_first_token = 0; + double time_to_first_token = 0.0; }; // Will be called for layers output with: diff --git a/gemma/run.cc b/gemma/run.cc index 372346b..e30d67b 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -25,6 +25,7 @@ // Placeholder for internal header, do not modify. #include "compression/compress.h" +#include "gemma/common.h" #include "gemma/configs.h" #include "gemma/gemma.h" // Gemma #include "util/app.h" @@ -98,12 +99,13 @@ void ShowHelp(gcpp::LoaderArgs& loader, gcpp::InferenceArgs& inference, std::cerr << "\n"; } +// The main Read-Eval-Print Loop. 
void ReplGemma(gcpp::Gemma& model, ModelTraining training, gcpp::KVCache& kv_cache, hwy::ThreadPool& pool, const InferenceArgs& args, int verbosity, const gcpp::AcceptFunc& accept_token, std::string& eot_line) { PROFILER_ZONE("Gen.misc"); - size_t abs_pos = 0; // absolute token index over all turns + size_t abs_pos = 0; // absolute token index over all turns int current_pos = 0; // token index within the current turn int prompt_size{}; @@ -160,7 +162,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training, std::cout << "> " << std::flush; } - if (eot_line.size() == 0) { + if (eot_line.empty()) { std::getline(std::cin, prompt_string); } else { std::string line; @@ -198,7 +200,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training, // For both pre-trained and instruction-tuned models: prepend "<bos>" token // if needed. if (abs_pos == 0) { - prompt.insert(prompt.begin(), 2); + prompt.insert(prompt.begin(), gcpp::BOS_ID); } prompt_size = prompt.size(); @@ -207,7 +209,7 @@ void ReplGemma(gcpp::Gemma& model, ModelTraining training, << "[ Reading prompt ] " << std::flush; if constexpr (kVerboseLogTokens) { - for (int i = 0; i < static_cast<int>(prompt.size()); ++i) { + for (int i = 0; i < prompt_size; ++i) { fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]); } } @@ -253,11 +255,6 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { KVCache kv_cache = KVCache::Create(loader.ModelType()); - if (const char* error = inference.Validate()) { - ShowHelp(loader, inference, app); - HWY_ABORT("\nInvalid args: %s", error); - } - if (app.verbosity >= 1) { const std::string instructions = "*Usage*\n" @@ -307,6 +304,11 @@ int main(int argc, char** argv) { HWY_ABORT("\nInvalid args: %s", error); } + if (const char* error = inference.Validate()) { + ShowHelp(loader, inference, app); + HWY_ABORT("\nInvalid args: %s", error); + } + gcpp::Run(loader, inference, app); } PROFILER_PRINT_RESULTS(); // Must call outside the zone above.