mirror of https://github.com/google/gemma.cpp.git
Add --eot_line option
This commit is contained in:
parent
7698e3c3de
commit
1a95cf3274
17
run.cc
17
run.cc
|
|
@ -79,7 +79,9 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
|
||||
void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
||||
hwy::ThreadPool& inner_pool, const InferenceArgs& args,
|
||||
int verbosity, const gcpp::AcceptFunc& accept_token) {
|
||||
int verbosity, const gcpp::AcceptFunc& accept_token,
|
||||
std::string &eot_line
|
||||
) {
|
||||
PROFILER_ZONE("Gen.misc");
|
||||
int abs_pos = 0; // absolute token index over all turns
|
||||
int current_pos = 0; // token index within the current turn
|
||||
|
|
@ -137,7 +139,18 @@ void ReplGemma(gcpp::Gemma& model, hwy::ThreadPool& pool,
|
|||
if (verbosity >= 1) {
|
||||
std::cout << "> " << std::flush;
|
||||
}
|
||||
|
||||
if (eot_line.size() == 0) {
|
||||
std::getline(std::cin, prompt_string);
|
||||
} else {
|
||||
std::string line;
|
||||
while (std::getline(std::cin, line)) {
|
||||
if (line == eot_line) {
|
||||
break;
|
||||
}
|
||||
prompt_string += line + "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (std::cin.fail() || prompt_string == "%q" || prompt_string == "%Q") {
|
||||
|
|
@ -231,7 +244,7 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
}
|
||||
|
||||
ReplGemma(model, pool, inner_pool, inference, app.verbosity,
|
||||
/*accept_token=*/[](int) { return true; });
|
||||
/*accept_token=*/[](int) { return true; }, app.eot_line);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -62,6 +62,7 @@ class AppArgs : public ArgsBase<AppArgs> {
|
|||
Path log; // output
|
||||
int verbosity;
|
||||
size_t num_threads;
|
||||
std::string eot_line;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
|
|
@ -77,6 +78,11 @@ class AppArgs : public ArgsBase<AppArgs> {
|
|||
"estimate of "
|
||||
"how many concurrent threads are supported.",
|
||||
2);
|
||||
visitor(eot_line, "eot_line", std::string(""),
|
||||
"End of turn line. "
|
||||
"When you specify this, the prompt will be all lines "
|
||||
"before the line where only the given string appears.",
|
||||
2);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue