From 6debdbe34195dbe9b79a91015b777ad1bb1f1f90 Mon Sep 17 00:00:00 2001
From: RangerUFO <ufownl@gmail.com>
Date: Tue, 20 May 2025 14:20:57 +0800
Subject: [PATCH] Minor fixes for ViT

---
 gemma/gemma-inl.h |  1 -
 gemma/run.cc      | 20 +++++++++++---------
 2 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index 21e388f..a1ed465 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -1556,7 +1556,6 @@ void GenerateImageTokensT(const ModelStore& model,
     prefill_runtime_config.prefill_tbatch_size =
         vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
     Activations prefill_activations(vit_config, vit_config.seq_len, env);
-    prefill_activations.SetBatchSize(prefill_runtime_config.prefill_tbatch_size);
     // Weights are for the full PaliGemma model, not just the ViT part.
     PrefillVit(weights, prefill_runtime_config, image, image_tokens,
                prefill_activations);
diff --git a/gemma/run.cc b/gemma/run.cc
index 88c2dcf..cba3500 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -136,7 +136,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
         ++tokens_generated_this_turn;
         if (in_prompt) {
           if (inference.verbosity >= 1) {
-            std::cerr << "." << std::flush;
+            std::cout << "." << std::flush;
           }
           return true;
         } else if (config.IsEOS(token)) {
@@ -188,19 +188,21 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                                   .use_spinning = threading.spin};
   inference.CopyTo(runtime_config);
   std::vector<int> prompt;
-  size_t prompt_size = 0;
   size_t prefix_end = 0;
   if (have_image) {
     prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
                              config.wrapping, abs_pos, prompt_string,
                              image_tokens.Rows());
     runtime_config.image_tokens = &image_tokens;
-    prompt_size = prompt.size();
-    // The end of the prefix for prefix-LM style attention in Paligemma.
-    // See Figure 2 of https://arxiv.org/abs/2407.07726.
-    prefix_end = prompt_size;
-    // We need to look at all the tokens for the prefix.
-    runtime_config.prefill_tbatch_size = prompt_size;
+    if (config.wrapping == PromptWrapping::PALIGEMMA) {
+      // The end of the prefix for prefix-LM style attention in Paligemma.
+      // See Figure 2 of https://arxiv.org/abs/2407.07726.
+      prefix_end = prompt.size();
+      // We need to look at all the tokens for the prefix.
+      // NOTE: Online softmax is on the roadmap, after which this requirement
+      // can be lifted.
+      runtime_config.prefill_tbatch_size = prompt.size();
+    }
   } else {
     prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
                              config.wrapping, abs_pos, prompt_string);
@@ -308,4 +310,4 @@ int main(int argc, char** argv) {
   }
   PROFILER_PRINT_RESULTS();  // Must call outside the zone above.
   return 0;
-}
\ No newline at end of file
+}