mirror of https://github.com/google/gemma.cpp.git
Address review feedback: Fix prefill_tbatch_size and variable placement issues
This commit is contained in:
parent 8246e49199
commit 27c28cc938

gemma/run.cc | 35
@@ -156,7 +156,8 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     std::cout << token_text << std::flush;
     return true;
   };
+  // Flag to check if we should exit after processing non-interactive prompt
+  bool exit_after_generation = !inference.prompt.empty();
   while (true) { // Loop until user quits.
     tokens_generated_this_turn = 0;
 
@@ -179,14 +180,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
       std::cout << "Use '%q' to quit.\n";
       continue;
     }
-    std::vector<int> prompt;
-    size_t prompt_size = 0;
-    size_t prefix_end = 0;
-    if constexpr (kVerboseLogTokens) {
-      for (int i = 0; i < prompt_size; ++i) {
-        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
-      }
-    }
 
     // Set up runtime config.
     TimingInfo timing_info = {.verbosity = inference.verbosity};
@@ -195,29 +188,31 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                                     .stream_token = stream_token,
                                     .use_spinning = threading.spin};
     inference.CopyTo(runtime_config);
+    std::vector<int> prompt;
+    size_t prompt_size = 0;
+    size_t prefix_end = 0;
     if (have_image) {
       prompt =
           WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
                           abs_pos, prompt_string, image_tokens.BatchSize());
+      runtime_config.image_tokens = &image_tokens;
+      prompt_size = prompt.size();
       // The end of the prefix for prefix-LM style attention in Paligemma.
       // See Figure 2 of https://arxiv.org/abs/2407.07726.
       prefix_end = prompt_size;
-      // We need to look at all the tokens for the prefix.
-      runtime_config.prefill_tbatch_size = prompt_size;
+      // REMOVED: Don't change prefill_tbatch_size for image handling
+      // runtime_config.prefill_tbatch_size = prompt_size;
     } else {
       prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
                                model.Info(), abs_pos, prompt_string);
       prompt_size = prompt.size();
     }
 
-    if (have_image) {
-      runtime_config.image_tokens = &image_tokens;
-      prompt_size = prompt.size();
-      // The end of the prefix for prefix-LM style attention in Paligemma.
-      prefix_end = prompt_size;
-      // We need to look at all the tokens for the prefix.
-      runtime_config.prefill_tbatch_size = prompt_size;
-    }
+    if constexpr (kVerboseLogTokens) {
+      for (int i = 0; i < prompt_size; ++i) {
+        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
+      }
+    }
 
     // Generate until EOS or max_generated_tokens.
 
@@ -229,7 +224,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     std::cout << "\n\n";
 
     // Break the loop if in non-interactive mode
-    if (inference.prompt.empty()) {
+    if (exit_after_generation) {
       break;
     }
 
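For readability, here is a condensed sketch of how the prompt-handling block reads once the large hunk above is applied. It is assembled from the + lines rather than copied verbatim from run.cc; model, runtime_config, have_image, image_tokens, abs_pos, and prompt_string are the surrounding ReplGemma locals and arguments:

    // Sketch assembled from the + lines of the diff above, not verbatim run.cc.
    // Prompt variables are declared once, after runtime_config is configured.
    std::vector<int> prompt;
    size_t prompt_size = 0;
    size_t prefix_end = 0;
    if (have_image) {
      prompt =
          WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
                          abs_pos, prompt_string, image_tokens.BatchSize());
      runtime_config.image_tokens = &image_tokens;
      prompt_size = prompt.size();
      // The end of the prefix for prefix-LM style attention in Paligemma.
      // See Figure 2 of https://arxiv.org/abs/2407.07726.
      prefix_end = prompt_size;
      // prefill_tbatch_size is intentionally no longer overridden here.
    } else {
      prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
                               model.Info(), abs_pos, prompt_string);
      prompt_size = prompt.size();
    }

Both review fixes are visible in one place: the duplicate if (have_image) block is gone, and prefill_tbatch_size is no longer resized to the image prompt length.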
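The first and last hunks combine to implement the non-interactive exit path: the flag is computed once before the REPL loop and checked after each turn. A minimal control-flow sketch, condensed from the + lines above (prompt reading, tokenization, and the Generate call are elided; inference and tokens_generated_this_turn come from the enclosing ReplGemma):

  // Condensed control-flow sketch from the + lines above; generation elided.
  bool exit_after_generation = !inference.prompt.empty();
  while (true) { // Loop until user quits.
    tokens_generated_this_turn = 0;
    // ... build the prompt and run generation as in the hunks above ...
    std::cout << "\n\n";
    // Break the loop if in non-interactive mode
    if (exit_after_generation) {
      break;
    }
  }

This replaces the in-loop check of inference.prompt.empty() shown as the removed line in the final hunk.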