Add --eot_line option

This commit is contained in:
Yuta Hayashibe 2024-02-24 20:25:07 +09:00
parent 7698e3c3de
commit 1a95cf3274
2 changed files with 22 additions and 3 deletions

19
run.cc
View File

@ -79,7 +79,9 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool, void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
hwy::ThreadPool& inner_pool, const InferenceArgs& args, hwy::ThreadPool& inner_pool, const InferenceArgs& args,
int verbosity, const gcpp::AcceptFunc& accept_token) { int verbosity, const gcpp::AcceptFunc& accept_token,
std::string &eot_line
) {
PROFILER_ZONE("Gen.misc"); PROFILER_ZONE("Gen.misc");
int abs_pos = 0; // absolute token index over all turns int abs_pos = 0; // absolute token index over all turns
int current_pos = 0; // token index within the current turn int current_pos = 0; // token index within the current turn
@ -137,7 +139,18 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
if (verbosity >= 1) { if (verbosity >= 1) {
std::cout << "> " << std::flush; std::cout << "> " << std::flush;
} }
std::getline(std::cin, prompt_string);
if (eot_line.size() == 0) {
std::getline(std::cin, prompt_string);
} else {
std::string line;
while (std::getline(std::cin, line)) {
if (line == eot_line) {
break;
}
prompt_string += line + "\n";
}
}
} }
if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") { if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") {
@ -231,7 +244,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
} }
ReplGemma(model, pool, inner_pool, inference, app.verbosity, ReplGemma(model, pool, inner_pool, inference, app.verbosity,
/*accept_token=*/[](int) { return true; }); /*accept_token=*/[](int) { return true; }, app.eot_line);
} }
} // namespace gcpp } // namespace gcpp

View File

@ -62,6 +62,7 @@ class AppArgs : public ArgsBase<AppArgs> {
Path log; // output Path log; // output
int verbosity; int verbosity;
size_t num_threads; size_t num_threads;
std::string eot_line;
template <class Visitor> template <class Visitor>
void ForEach(const Visitor& visitor) { void ForEach(const Visitor& visitor) {
@ -77,6 +78,11 @@ class AppArgs : public ArgsBase<AppArgs> {
"estimate of " "estimate of "
"how many concurrent threads are supported.", "how many concurrent threads are supported.",
2); 2);
visitor(eot_line, "eot_line", std::string(""),
"End of turn line. "
"When you specify this, the prompt will be all lines "
"before the line where only the given string appears.",
2);
} }
}; };