mirror of https://github.com/google/gemma.cpp.git
Address review feedback: Fix prefill_tbatch_size and variable placement issues
parent 8246e49199
commit 27c28cc938

gemma/run.cc (35 changed lines)
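The non-interactive part of the change adds an exit_after_generation flag: when inference.prompt is non-empty (a prompt supplied up front rather than typed at the REPL), the loop generates one reply and then breaks instead of prompting again. Below is a minimal sketch of that control flow; InferenceArgsSketch, ReplSketch, and the placeholder "reply" line are hypothetical stand-ins, not gemma.cpp's ReplGemma or InferenceArgs.

#include <iostream>
#include <string>

// Hypothetical stand-in for the one InferenceArgs field the sketch needs.
struct InferenceArgsSketch {
  std::string prompt;  // non-empty => non-interactive, one-shot mode
};

void ReplSketch(const InferenceArgsSketch& inference) {
  // Flag to check if we should exit after processing non-interactive prompt.
  const bool exit_after_generation = !inference.prompt.empty();
  while (true) {  // Loop until user quits.
    std::string prompt_string = inference.prompt;
    if (!exit_after_generation) {
      std::cout << "> " << std::flush;
      if (!std::getline(std::cin, prompt_string) || prompt_string == "%q") break;
    }
    // Placeholder for tokenization + generation.
    std::cout << "(reply to: " << prompt_string << ")\n\n";
    // Break the loop if in non-interactive mode.
    if (exit_after_generation) break;
  }
}

int main() {
  ReplSketch({"Why is the sky blue?"});  // generates once, then exits
}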
@@ -156,7 +156,8 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
      std::cout << token_text << std::flush;
      return true;
    };

  // Flag to check if we should exit after processing non-interactive prompt
  bool exit_after_generation = !inference.prompt.empty();
  while (true) {  // Loop until user quits.
    tokens_generated_this_turn = 0;
@@ -179,14 +180,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
      std::cout << "Use '%q' to quit.\n";
      continue;
    }
    std::vector<int> prompt;
    size_t prompt_size = 0;
    size_t prefix_end = 0;
    if constexpr (kVerboseLogTokens) {
      for (int i = 0; i < prompt_size; ++i) {
        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
      }
    }

    // Set up runtime config.
    TimingInfo timing_info = {.verbosity = inference.verbosity};
@@ -195,29 +188,31 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
        .stream_token = stream_token,
        .use_spinning = threading.spin};
    inference.CopyTo(runtime_config);

    std::vector<int> prompt;
    size_t prompt_size = 0;
    size_t prefix_end = 0;
    if (have_image) {
      prompt =
          WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
                          abs_pos, prompt_string, image_tokens.BatchSize());
      runtime_config.image_tokens = &image_tokens;
      prompt_size = prompt.size();
      // The end of the prefix for prefix-LM style attention in Paligemma.
      // See Figure 2 of https://arxiv.org/abs/2407.07726.
      prefix_end = prompt_size;
      // We need to look at all the tokens for the prefix.
      runtime_config.prefill_tbatch_size = prompt_size;

      // REMOVED: Don't change prefill_tbatch_size for image handling
      // runtime_config.prefill_tbatch_size = prompt_size;
    } else {
      prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
                               model.Info(), abs_pos, prompt_string);
      prompt_size = prompt.size();
    }

    if (have_image) {
      runtime_config.image_tokens = &image_tokens;
      prompt_size = prompt.size();
      // The end of the prefix for prefix-LM style attention in Paligemma.
      prefix_end = prompt_size;
      // We need to look at all the tokens for the prefix.
      runtime_config.prefill_tbatch_size = prompt_size;
    if constexpr (kVerboseLogTokens) {
      for (int i = 0; i < prompt_size; ++i) {
        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
      }
    }

    // Generate until EOS or max_generated_tokens.
@@ -229,7 +224,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
      std::cout << "\n\n";

      // Break the loop if in non-interactive mode
      if (inference.prompt.empty()) {
      if (exit_after_generation) {
        break;
      }
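On the prefill_tbatch_size side, the @@ -195,29 +188,31 @@ hunk restores runtime_config.prefill_tbatch_size = prompt_size; for image prompts and drops the stale "REMOVED" comments. Per the in-code comments, Paligemma's prefix-LM attention (Figure 2 of https://arxiv.org/abs/2407.07726) requires looking at all tokens of the image+text prefix together, so the prefill token-batch size is raised to the full prompt length. The sketch below only illustrates how a token-batch size chunks a prompt during prefill; the sizes are arbitrary and PrintPrefillBatches is a hypothetical helper, not gemma.cpp's prefill loop.

#include <algorithm>
#include <cstdio>

// Print the [start, end) token ranges that a prefill pass would process per batch.
static void PrintPrefillBatches(size_t prompt_size, size_t tbatch_size) {
  for (size_t start = 0; start < prompt_size; start += tbatch_size) {
    const size_t end = std::min(start + tbatch_size, prompt_size);
    std::printf("  batch covers tokens [%zu, %zu)\n", start, end);
  }
}

int main() {
  const size_t prompt_size = 600;  // arbitrary: image tokens followed by text tokens
  std::printf("small tbatch (64) splits the prefix across batches:\n");
  PrintPrefillBatches(prompt_size, 64);
  std::printf("tbatch = prompt_size keeps the whole prefix in one batch:\n");
  PrintPrefillBatches(prompt_size, prompt_size);
}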