mirror of https://github.com/google/gemma.cpp.git
Add --eot_line option
This commit is contained in:
parent
7698e3c3de
commit
1a95cf3274
19
run.cc
19
run.cc
|
|
@ -79,7 +79,9 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||||
|
|
||||||
void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
||||||
hwy::ThreadPool& inner_pool, const InferenceArgs& args,
|
hwy::ThreadPool& inner_pool, const InferenceArgs& args,
|
||||||
int verbosity, const gcpp::AcceptFunc& accept_token) {
|
int verbosity, const gcpp::AcceptFunc& accept_token,
|
||||||
|
std::string &eot_line
|
||||||
|
) {
|
||||||
PROFILER_ZONE("Gen.misc");
|
PROFILER_ZONE("Gen.misc");
|
||||||
int abs_pos = 0; // absolute token index over all turns
|
int abs_pos = 0; // absolute token index over all turns
|
||||||
int current_pos = 0; // token index within the current turn
|
int current_pos = 0; // token index within the current turn
|
||||||
|
|
@ -137,7 +139,18 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
||||||
if (verbosity >= 1) {
|
if (verbosity >= 1) {
|
||||||
std::cout << "> " << std::flush;
|
std::cout << "> " << std::flush;
|
||||||
}
|
}
|
||||||
std::getline(std::cin, prompt_string);
|
|
||||||
|
if (eot_line.size() == 0) {
|
||||||
|
std::getline(std::cin, prompt_string);
|
||||||
|
} else {
|
||||||
|
std::string line;
|
||||||
|
while (std::getline(std::cin, line)) {
|
||||||
|
if (line == eot_line) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
prompt_string += line + "\n";
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") {
|
if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") {
|
||||||
|
|
@ -231,7 +244,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ReplGemma(model, pool, inner_pool, inference, app.verbosity,
|
ReplGemma(model, pool, inner_pool, inference, app.verbosity,
|
||||||
/*accept_token=*/[](int) { return true; });
|
/*accept_token=*/[](int) { return true; }, app.eot_line);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,7 @@ class AppArgs : public ArgsBase<AppArgs> {
|
||||||
Path log; // output
|
Path log; // output
|
||||||
int verbosity;
|
int verbosity;
|
||||||
size_t num_threads;
|
size_t num_threads;
|
||||||
|
std::string eot_line;
|
||||||
|
|
||||||
template <class Visitor>
|
template <class Visitor>
|
||||||
void ForEach(const Visitor& visitor) {
|
void ForEach(const Visitor& visitor) {
|
||||||
|
|
@ -77,6 +78,11 @@ class AppArgs : public ArgsBase<AppArgs> {
|
||||||
"estimate of "
|
"estimate of "
|
||||||
"how many concurrent threads are supported.",
|
"how many concurrent threads are supported.",
|
||||||
2);
|
2);
|
||||||
|
visitor(eot_line, "eot_line", std::string(""),
|
||||||
|
"End of turn line. "
|
||||||
|
"When you specify this, the prompt will be all lines "
|
||||||
|
"before the line where only the given string appears.",
|
||||||
|
2);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue