mirror of https://github.com/google/gemma.cpp.git
Fix handling of %c and %q if eot_string. Fixes #283, thanks @ljcucc
PiperOrigin-RevId: 649651535
This commit is contained in:
parent
f823371691
commit
438b1bace2
56
gemma/run.cc
56
gemma/run.cc
|
|
@ -52,6 +52,28 @@ static constexpr std::string_view kAsciiArtBanner = R""(
|
||||||
|___/ |_| |_|
|
|___/ |_| |_|
|
||||||
)"";
|
)"";
|
||||||
|
|
||||||
|
std::string GetPrompt(std::istream& input, int verbosity,
|
||||||
|
std::string_view eot_line) {
|
||||||
|
PROFILER_ZONE("Gen.input");
|
||||||
|
if (verbosity >= 1) {
|
||||||
|
std::cout << "> " << std::flush;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string prompt_string;
|
||||||
|
if (eot_line.empty()) {
|
||||||
|
std::getline(input, prompt_string);
|
||||||
|
} else {
|
||||||
|
std::string line;
|
||||||
|
while (std::getline(input, line)) {
|
||||||
|
if (line == eot_line) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
prompt_string += line + "\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return prompt_string;
|
||||||
|
}
|
||||||
|
|
||||||
// The main Read-Eval-Print Loop.
|
// The main Read-Eval-Print Loop.
|
||||||
void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
|
void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
const InferenceArgs& args, int verbosity,
|
const InferenceArgs& args, int verbosity,
|
||||||
|
|
@ -100,34 +122,16 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, hwy::ThreadPool& pool,
|
||||||
};
|
};
|
||||||
|
|
||||||
while (abs_pos < args.max_tokens) {
|
while (abs_pos < args.max_tokens) {
|
||||||
std::string prompt_string;
|
|
||||||
current_pos = 0;
|
current_pos = 0;
|
||||||
{
|
std::string prompt_string = GetPrompt(std::cin, verbosity, eot_line);
|
||||||
PROFILER_ZONE("Gen.input");
|
if (!std::cin) return;
|
||||||
if (verbosity >= 1) {
|
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
|
||||||
std::cout << "> " << std::flush;
|
if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
|
||||||
|
if (prompt_string[1] == 'q' || prompt_string[1] == 'Q') return;
|
||||||
|
if (prompt_string[1] == 'c' || prompt_string[1] == 'C') {
|
||||||
|
abs_pos = 0;
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (eot_line.empty()) {
|
|
||||||
std::getline(std::cin, prompt_string);
|
|
||||||
} else {
|
|
||||||
std::string line;
|
|
||||||
while (std::getline(std::cin, line)) {
|
|
||||||
if (line == eot_line) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
prompt_string += line + "\n";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (prompt_string == "%c" || prompt_string == "%C") {
|
|
||||||
abs_pos = 0;
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<int> prompt = WrapAndTokenize(
|
const std::vector<int> prompt = WrapAndTokenize(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue