diff --git a/gemma/run.cc b/gemma/run.cc index ccad159..cb2a2d5 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -52,6 +52,28 @@ static constexpr std::string_view kAsciiArtBanner = R""( |___/ |_| |_| )""; +std::string GetPrompt(std::istream& input, int verbosity, + std::string_view eot_line) { + PROFILER_ZONE("Gen.input"); + if (verbosity >= 1) { + std::cout << "> " << std::flush; + } + + std::string prompt_string; + if (eot_line.empty()) { + std::getline(input, prompt_string); + } else { + std::string line; + while (std::getline(input, line)) { + if (line == eot_line) { + break; + } + prompt_string += line + "\n"; + } + } + return prompt_string; +} + // The main Read-Eval-Print Loop. void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool, const InferenceArgs& args, int verbosity, @@ -100,34 +122,16 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool, }; while (abs_pos < args.max_tokens) { - std::string prompt_string; current_pos = 0; - { - PROFILER_ZONE("Gen.input"); - if (verbosity >= 1) { - std::cout << "> " << std::flush; + std::string prompt_string = GetPrompt(std::cin, verbosity, eot_line); + if (!std::cin) return; + // If !eot_line.empty(), we append \n, so only look at the first 2 chars. + if (prompt_string.size() >= 2 && prompt_string[0] == '%') { + if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return; + if (prompt_string[1] == 'c' || prompt_string[1] == 'C') { + abs_pos = 0; + continue; } - - if (eot_line.empty()) { - std::getline(std::cin, prompt_string); - } else { - std::string line; - while (std::getline(std::cin, line)) { - if (line == eot_line) { - break; - } - prompt_string += line + "\n"; - } - } - } - - if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") { - return; - } - - if (prompt_string == "%c" || prompt_string == "%C") { - abs_pos = 0; - continue; } const std::vector prompt = WrapAndTokenize(