Address review feedback: Fix prefill_tbatch_size and variable placement issues

prajwalc22 2025-04-17 10:15:05 +05:30
parent 8246e49199
commit 27c28cc938
1 changed file with 15 additions and 20 deletions
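The hunks below keep `prefix_end` in place for PaliGemma's prefix-LM style attention (Figure 2 of https://arxiv.org/abs/2407.07726) but stop overriding `prefill_tbatch_size`. As a rough standalone illustration of what `prefix_end` means (a sketch, not code from this repository): prefix positions (image plus prompt tokens) attend to the whole prefix bidirectionally, while later, generated positions attend causally.

// Sketch only: illustrates prefix-LM masking, not gemma.cpp internals.
#include <cstddef>
#include <cstdio>
#include <vector>

// mask[q][k] == 1 iff query position q may attend to key position k.
std::vector<std::vector<int>> PrefixLmMask(size_t seq_len, size_t prefix_end) {
  std::vector<std::vector<int>> mask(seq_len, std::vector<int>(seq_len, 0));
  for (size_t q = 0; q < seq_len; ++q) {
    // Prefix queries see the whole prefix; later queries see keys 0..q.
    const size_t visible = (q < prefix_end) ? prefix_end : q + 1;
    for (size_t k = 0; k < visible; ++k) mask[q][k] = 1;
  }
  return mask;
}

int main() {
  // Six positions, of which the first four are the image+prompt prefix.
  const auto mask = PrefixLmMask(6, 4);
  for (const auto& row : mask) {
    for (int v : row) printf("%d ", v);
    printf("\n");
  }
  return 0;
}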


@@ -156,7 +156,8 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     std::cout << token_text << std::flush;
     return true;
   };
+  // Flag to check if we should exit after processing a non-interactive prompt.
+  bool exit_after_generation = !inference.prompt.empty();
   while (true) {  // Loop until user quits.
     tokens_generated_this_turn = 0;
@@ -179,14 +180,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
       std::cout << "Use '%q' to quit.\n";
       continue;
     }
-    std::vector<int> prompt;
-    size_t prompt_size = 0;
-    size_t prefix_end = 0;
-    if constexpr (kVerboseLogTokens) {
-      for (int i = 0; i < prompt_size; ++i) {
-        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
-      }
-    }

     // Set up runtime config.
     TimingInfo timing_info = {.verbosity = inference.verbosity};
@@ -195,29 +188,31 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                                     .stream_token = stream_token,
                                     .use_spinning = threading.spin};
     inference.CopyTo(runtime_config);

+    std::vector<int> prompt;
+    size_t prompt_size = 0;
+    size_t prefix_end = 0;
     if (have_image) {
       prompt =
           WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
                          abs_pos, prompt_string, image_tokens.BatchSize());
+      runtime_config.image_tokens = &image_tokens;
       prompt_size = prompt.size();
       // The end of the prefix for prefix-LM style attention in Paligemma.
       // See Figure 2 of https://arxiv.org/abs/2407.07726.
       prefix_end = prompt_size;
       // We need to look at all the tokens for the prefix.
-      runtime_config.prefill_tbatch_size = prompt_size;
+      // REMOVED: Don't change prefill_tbatch_size for image handling.
+      // runtime_config.prefill_tbatch_size = prompt_size;
     } else {
       prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
                               model.Info(), abs_pos, prompt_string);
       prompt_size = prompt.size();
     }
-    if (have_image) {
-      runtime_config.image_tokens = &image_tokens;
-      prompt_size = prompt.size();
-      // The end of the prefix for prefix-LM style attention in Paligemma.
-      prefix_end = prompt_size;
-      // We need to look at all the tokens for the prefix.
-      runtime_config.prefill_tbatch_size = prompt_size;
-    }
+    if constexpr (kVerboseLogTokens) {
+      for (int i = 0; i < prompt_size; ++i) {
+        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
+      }
+    }

     // Generate until EOS or max_generated_tokens.
@@ -229,7 +224,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     std::cout << "\n\n";

     // Break the loop if in non-interactive mode
-    if (inference.prompt.empty()) {
+    if (exit_after_generation) {
       break;
     }

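For context on the `exit_after_generation` change: the flag is captured once, before the REPL loop, so the exit check no longer depends on re-reading `inference.prompt` after each turn. A minimal standalone sketch of the pattern, with a simplified `InferenceArgs` and a hypothetical `ReadPromptInteractive` helper standing in for the surrounding input-reading code:

// Sketch of the exit-after-generation pattern; not the gemma.cpp sources.
#include <iostream>
#include <string>

struct InferenceArgs {
  std::string prompt;  // Non-empty => run one generation, then exit.
};

// Hypothetical stand-in for the interactive prompt reader.
std::string ReadPromptInteractive() {
  std::cout << "> " << std::flush;
  std::string line;
  std::getline(std::cin, line);
  return line;
}

void Repl(const InferenceArgs& inference) {
  // Decide once, up front: a command-line prompt means a single
  // non-interactive turn, regardless of what happens inside the loop.
  const bool exit_after_generation = !inference.prompt.empty();
  while (true) {
    const std::string prompt_string =
        exit_after_generation ? inference.prompt : ReadPromptInteractive();
    if (prompt_string.empty()) break;  // EOF or empty input ends the session.
    // ... tokenize prompt_string and stream generated tokens here ...
    std::cout << "\n\n";
    if (exit_after_generation) break;  // Non-interactive: one turn only.
  }
}

int main(int argc, char** argv) {
  InferenceArgs args;
  if (argc > 1) args.prompt = argv[1];  // e.g. ./repl "Hello"
  Repl(args);
  return 0;
}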