From 1a95cf32745ca3d75a3a09c948812f093113f1a0 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Sat, 24 Feb 2024 20:25:07 +0900 Subject: [PATCH] Add --eot_line option --- run.cc | 19 ++++++++++++++++--- util/app.h | 6 ++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/run.cc b/run.cc index 87d8445..526ea8f 100644 --- a/run.cc +++ b/run.cc @@ -79,7 +79,9 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, hwy::ThreadPool& inner_pool, const InferenceArgs& args, - int verbosity, const gcpp::AcceptFunc& accept_token) { + int verbosity, const gcpp::AcceptFunc& accept_token, + std::string &eot_line +) { PROFILER_ZONE("Gen.misc"); int abs_pos = 0; // absolute token index over all turns int current_pos = 0; // token index within the current turn @@ -137,7 +139,18 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, if (verbosity >= 1) { std::cout << "> " << std::flush; } - std::getline(std::cin, prompt_string); + + if (eot_line.size() == 0) { + std::getline(std::cin, prompt_string); + } else { + std::string line; + while (std::getline(std::cin, line)) { + if (line == eot_line) { + break; + } + prompt_string += line + "\n"; + } + } } if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") { @@ -231,7 +244,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { } ReplGemma(model, pool, inner_pool, inference, app.verbosity, - /*accept_token=*/[](int) { return true; }); + /*accept_token=*/[](int) { return true; }, app.eot_line); } } // namespace gcpp diff --git a/util/app.h b/util/app.h index 966fa41..8eb672b 100644 --- a/util/app.h +++ b/util/app.h @@ -62,6 +62,7 @@ class AppArgs : public ArgsBase { Path log; // output int verbosity; size_t num_threads; + std::string eot_line; template void ForEach(const Visitor& visitor) { @@ -77,6 +78,11 @@ class AppArgs : public ArgsBase { "estimate of " "how many concurrent threads are supported.", 2); + visitor(eot_line, "eot_line", std::string(""), + "End of turn line. " + "When you specify this, the prompt will be all lines " + "before the line where only the given string appears.", + 2); } };