Fix handling of %c and %q if eot_string. Fixes #283, thanks @ljcucc

PiperOrigin-RevId: 649651535
This commit is contained in:
Jan Wassenberg 2024-07-05 07:53:24 -07:00 committed by Copybara-Service
parent f823371691
commit 438b1bace2
1 changed files with 30 additions and 26 deletions

View File

@ -52,6 +52,28 @@ static constexpr std::string_view kAsciiArtBanner = R""(
|___/ |_| |_| |___/ |_| |_|
)""; )"";
std::string GetPrompt(std::istream& input, int verbosity,
std::string_view eot_line) {
PROFILER_ZONE("Gen.input");
if (verbosity >= 1) {
std::cout << "> " << std::flush;
}
std::string prompt_string;
if (eot_line.empty()) {
std::getline(input, prompt_string);
} else {
std::string line;
while (std::getline(input, line)) {
if (line == eot_line) {
break;
}
prompt_string += line + "\n";
}
}
return prompt_string;
}
// The main Read-Eval-Print Loop. // The main Read-Eval-Print Loop.
void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool, void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
const InferenceArgs& args, int verbosity, const InferenceArgs& args, int verbosity,
@ -100,35 +122,17 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
}; };
while (abs_pos < args.max_tokens) { while (abs_pos < args.max_tokens) {
std::string prompt_string;
current_pos = 0; current_pos = 0;
{ std::string prompt_string = GetPrompt(std::cin, verbosity, eot_line);
PROFILER_ZONE("Gen.input"); if (!std::cin) return;
if (verbosity >= 1) { // If !eot_line.empty(), we append \n, so only look at the first 2 chars.
std::cout << "> " << std::flush; if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
} if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
if (eot_line.empty()) {
std::getline(std::cin, prompt_string);
} else {
std::string line;
while (std::getline(std::cin, line)) {
if (line == eot_line) {
break;
}
prompt_string += line + "\n";
}
}
}
if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") {
return;
}
if (prompt_string == "%c" || prompt_string == "%C") {
abs_pos = 0; abs_pos = 0;
continue; continue;
} }
}
const std::vector<int> prompt = WrapAndTokenize( const std::vector<int> prompt = WrapAndTokenize(
model.Tokenizer(), model.Info(), abs_pos, prompt_string); model.Tokenizer(), model.Info(), abs_pos, prompt_string);