Address review feedback: Fix prefill_tbatch_size and variable placement issues

prajwalc22 2025-04-17 10:15:05 +05:30
parent 8246e49199
commit 27c28cc938
1 changed file with 15 additions and 20 deletions


@@ -156,7 +156,8 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     std::cout << token_text << std::flush;
     return true;
   };
+  // Flag to check if we should exit after processing non-interactive prompt
+  bool exit_after_generation = !inference.prompt.empty();

   while (true) {  // Loop until user quits.
     tokens_generated_this_turn = 0;
@@ -179,14 +180,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
       std::cout << "Use '%q' to quit.\n";
       continue;
     }
-    std::vector<int> prompt;
-    size_t prompt_size = 0;
-    size_t prefix_end = 0;
-    if constexpr (kVerboseLogTokens) {
-      for (int i = 0; i < prompt_size; ++i) {
-        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
-      }
-    }

     // Set up runtime config.
     TimingInfo timing_info = {.verbosity = inference.verbosity};
@@ -195,29 +188,31 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                                     .stream_token = stream_token,
                                     .use_spinning = threading.spin};
     inference.CopyTo(runtime_config);
+    std::vector<int> prompt;
+    size_t prompt_size = 0;
+    size_t prefix_end = 0;
     if (have_image) {
       prompt =
           WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
                           abs_pos, prompt_string, image_tokens.BatchSize());
+      runtime_config.image_tokens = &image_tokens;
+      prompt_size = prompt.size();
       // The end of the prefix for prefix-LM style attention in Paligemma.
       // See Figure 2 of https://arxiv.org/abs/2407.07726.
       prefix_end = prompt_size;
-      // We need to look at all the tokens for the prefix.
-      runtime_config.prefill_tbatch_size = prompt_size;
+      // REMOVED: Don't change prefill_tbatch_size for image handling
+      // runtime_config.prefill_tbatch_size = prompt_size;
     } else {
       prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
                                model.Info(), abs_pos, prompt_string);
       prompt_size = prompt.size();
     }
-    if (have_image) {
-      runtime_config.image_tokens = &image_tokens;
-      prompt_size = prompt.size();
-      // The end of the prefix for prefix-LM style attention in Paligemma.
-      prefix_end = prompt_size;
-      // We need to look at all the tokens for the prefix.
-      runtime_config.prefill_tbatch_size = prompt_size;
+    if constexpr (kVerboseLogTokens) {
+      for (int i = 0; i < prompt_size; ++i) {
+        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
+      }
     }

     // Generate until EOS or max_generated_tokens.
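The variable-placement part of this change is easiest to see against the earlier hunk: the old code declared prompt, prompt_size, and prefix_end and ran the kVerboseLogTokens dump before WrapAndTokenize had produced any tokens, so the loop iterated over a still-empty prompt and printed nothing. The hunk above moves the declarations next to the tokenization and the dump after it. A standalone sketch of that pitfall (not gemma.cpp code; the token values are made up):

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> prompt;                  // declared, but nothing tokenized yet
  std::size_t prompt_size = prompt.size();  // 0 at this point

  // Logging here (the old placement) prints nothing.
  for (std::size_t i = 0; i < prompt_size; ++i) {
    std::fprintf(stderr, "DDD TOKEN %3zu: %6d\n", i, prompt[i]);
  }

  prompt = {2, 106, 1645, 108};  // stand-in for WrapAndTokenize() output
  prompt_size = prompt.size();

  // Logging after tokenization (the new placement) sees every token.
  for (std::size_t i = 0; i < prompt_size; ++i) {
    std::fprintf(stderr, "DDD TOKEN %3zu: %6d\n", i, prompt[i]);
  }
  return 0;
}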
@@ -229,7 +224,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     std::cout << "\n\n";
     // Break the loop if in non-interactive mode
-    if (inference.prompt.empty()) {
+    if (exit_after_generation) {
       break;
     }
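The other half of the change is the non-interactive exit path: exit_after_generation is computed once from inference.prompt before the REPL loop starts, and the post-generation check reads that flag instead of re-testing inference.prompt.empty() (which broke when the prompt was empty rather than when one had been supplied). A minimal standalone sketch of the pattern, not the actual gemma.cpp loop:

#include <iostream>
#include <string>

int main(int argc, char** argv) {
  // Hypothetical stand-in for InferenceArgs::prompt: a prompt passed on the
  // command line means non-interactive, single-shot operation.
  const std::string cli_prompt = (argc > 1) ? argv[1] : "";
  const bool exit_after_generation = !cli_prompt.empty();

  while (true) {  // Loop until the user quits (or one generation completes).
    std::string prompt_string = cli_prompt;
    if (prompt_string.empty() && !std::getline(std::cin, prompt_string)) {
      break;  // interactive mode: stop on EOF
    }
    std::cout << "generated reply for: " << prompt_string << "\n\n";  // placeholder for generation
    if (exit_after_generation) {
      break;  // non-interactive mode: one turn, then exit
    }
  }
  return 0;
}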