Fix handling of %c and %q if eot_string. Fixes #283, thanks @ljcucc

PiperOrigin-RevId: 649651535
This commit is contained in:
Jan Wassenberg 2024-07-05 07:53:24 -07:00 committed by Copybara-Service
parent f823371691
commit 438b1bace2
1 changed files with 30 additions and 26 deletions

View File

@ -52,6 +52,28 @@ static constexpr std::string_view kAsciiArtBanner = R""(
|___/ |_| |_|
)"";
std::string GetPrompt(std::istream& input, int verbosity,
std::string_view eot_line) {
PROFILER_ZONE("Gen.input");
if (verbosity >= 1) {
std::cout << "> " << std::flush;
}
std::string prompt_string;
if (eot_line.empty()) {
std::getline(input, prompt_string);
} else {
std::string line;
while (std::getline(input, line)) {
if (line == eot_line) {
break;
}
prompt_string += line + "\n";
}
}
return prompt_string;
}
// The main Read-Eval-Print Loop.
void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
const InferenceArgs& args, int verbosity,
@ -100,34 +122,16 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
};
while (abs_pos < args.max_tokens) {
std::string prompt_string;
current_pos = 0;
{
PROFILER_ZONE("Gen.input");
if (verbosity >= 1) {
std::cout << "> " << std::flush;
std::string prompt_string = GetPrompt(std::cin, verbosity, eot_line);
if (!std::cin) return;
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
abs_pos = 0;
continue;
}
if (eot_line.empty()) {
std::getline(std::cin, prompt_string);
} else {
std::string line;
while (std::getline(std::cin, line)) {
if (line == eot_line) {
break;
}
prompt_string += line + "\n";
}
}
}
if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") {
return;
}
if (prompt_string == "%c" || prompt_string == "%C") {
abs_pos = 0;
continue;
}
const std::vector<int> prompt = WrapAndTokenize(