mirror of https://github.com/google/gemma.cpp.git
Fix handling of %c and %q if eot_string. Fixes #283, thanks @ljcucc
PiperOrigin-RevId: 649651535
This commit is contained in:
parent
f823371691
commit
438b1bace2
56
gemma/run.cc
56
gemma/run.cc
|
|
@ -52,6 +52,28 @@ static constexpr std::string_view kAsciiArtBanner = R""(
|
|||
|___/ |_| |_|
|
||||
)"";
|
||||
|
||||
std::string GetPrompt(std::istream& input, int verbosity,
|
||||
std::string_view eot_line) {
|
||||
PROFILER_ZONE("Gen.input");
|
||||
if (verbosity >= 1) {
|
||||
std::cout << "> " << std::flush;
|
||||
}
|
||||
|
||||
std::string prompt_string;
|
||||
if (eot_line.empty()) {
|
||||
std::getline(input, prompt_string);
|
||||
} else {
|
||||
std::string line;
|
||||
while (std::getline(input, line)) {
|
||||
if (line == eot_line) {
|
||||
break;
|
||||
}
|
||||
prompt_string += line + "\n";
|
||||
}
|
||||
}
|
||||
return prompt_string;
|
||||
}
|
||||
|
||||
// The main Read-Eval-Print Loop.
|
||||
void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||
const InferenceArgs& args, int verbosity,
|
||||
|
|
@ -100,34 +122,16 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
|
|||
};
|
||||
|
||||
while (abs_pos < args.max_tokens) {
|
||||
std::string prompt_string;
|
||||
current_pos = 0;
|
||||
{
|
||||
PROFILER_ZONE("Gen.input");
|
||||
if (verbosity >= 1) {
|
||||
std::cout << "> " << std::flush;
|
||||
std::string prompt_string = GetPrompt(std::cin, verbosity, eot_line);
|
||||
if (!std::cin) return;
|
||||
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
|
||||
if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
|
||||
if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
|
||||
if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
|
||||
abs_pos = 0;
|
||||
continue;
|
||||
}
|
||||
|
||||
if (eot_line.empty()) {
|
||||
std::getline(std::cin, prompt_string);
|
||||
} else {
|
||||
std::string line;
|
||||
while (std::getline(std::cin, line)) {
|
||||
if (line == eot_line) {
|
||||
break;
|
||||
}
|
||||
prompt_string += line + "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") {
|
||||
return;
|
||||
}
|
||||
|
||||
if (prompt_string == "%c" || prompt_string == "%C") {
|
||||
abs_pos = 0;
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::vector<int> prompt = WrapAndTokenize(
|
||||
|
|
|
|||
Loading…
Reference in New Issue