diff --git a/gemma/run.cc b/gemma/run.cc
index b7e8fa1..36d1bc2 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -156,7 +156,8 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     std::cout << token_text << std::flush;
     return true;
   };
-
+  // Flag to check if we should exit after processing a non-interactive prompt.
+  bool exit_after_generation = !inference.prompt.empty();
   while (true) {  // Loop until user quits.
     tokens_generated_this_turn = 0;
 
@@ -179,14 +180,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
       std::cout << "Use '%q' to quit.\n";
       continue;
     }
-    std::vector<int> prompt;
-    size_t prompt_size = 0;
-    size_t prefix_end = 0;
-    if constexpr (kVerboseLogTokens) {
-      for (int i = 0; i < prompt_size; ++i) {
-        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
-      }
-    }
 
     // Set up runtime config.
     TimingInfo timing_info = {.verbosity = inference.verbosity};
@@ -195,29 +188,31 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                                     .stream_token = stream_token,
                                     .use_spinning = threading.spin};
     inference.CopyTo(runtime_config);
-
+    std::vector<int> prompt;
+    size_t prompt_size = 0;
+    size_t prefix_end = 0;
     if (have_image) {
       prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
                                model.Info(), abs_pos, prompt_string,
                                image_tokens.BatchSize());
+      runtime_config.image_tokens = &image_tokens;
+
       prompt_size = prompt.size();
       // The end of the prefix for prefix-LM style attention in Paligemma.
       // See Figure 2 of https://arxiv.org/abs/2407.07726.
       prefix_end = prompt_size;
-      // We need to look at all the tokens for the prefix.
-      runtime_config.prefill_tbatch_size = prompt_size;
+
+      // REMOVED: Don't change prefill_tbatch_size for image handling.
+      // runtime_config.prefill_tbatch_size = prompt_size;
     } else {
       prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
                                model.Info(), abs_pos, prompt_string);
       prompt_size = prompt.size();
     }
-    if (have_image) {
-      runtime_config.image_tokens = &image_tokens;
-      prompt_size = prompt.size();
-      // The end of the prefix for prefix-LM style attention in Paligemma.
-      prefix_end = prompt_size;
-      // We need to look at all the tokens for the prefix.
-      runtime_config.prefill_tbatch_size = prompt_size;
+    if constexpr (kVerboseLogTokens) {
+      for (int i = 0; i < prompt_size; ++i) {
+        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
+      }
     }
 
     // Generate until EOS or max_generated_tokens.
@@ -229,7 +224,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     std::cout << "\n\n";
 
     // Break the loop if in non-interactive mode.
-    if (inference.prompt.empty()) {
+    if (exit_after_generation) {
       break;
     }
 
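
Note on the exit-flag change (first and last hunks): latching exit_after_generation before the loop decouples the exit decision from the later state of inference.prompt, and the original check (inference.prompt.empty()) also reads inverted relative to its "non-interactive mode" comment. Below is a minimal standalone sketch of the latch-before-loop pattern, assuming inference.prompt may no longer reflect its startup value once the first turn has consumed it; the struct and harness here are hypothetical, only the field and flag names mirror the patch.

// Standalone sketch (not part of the patch). Hypothetical harness showing
// why the non-interactive exit flag is latched before the REPL loop.
#include <iostream>
#include <string>

struct InferenceArgs {
  std::string prompt;  // non-empty => one-shot, non-interactive run
};

int main() {
  InferenceArgs inference{"Why is the sky blue?"};
  // Latch the mode up front: deciding from inference.prompt after the turn
  // would give the wrong answer if the prompt had been consumed by then.
  const bool exit_after_generation = !inference.prompt.empty();
  while (true) {
    // ... tokenize, generate, stream output ...
    inference.prompt.clear();  // assumed: prompt is consumed on first turn
    if (exit_after_generation) break;  // one-shot mode: stop after one turn
  }
  std::cout << "exited after one generation\n";
  return 0;
}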