Merge pull request #573 from ufownl:bugfix/vit

PiperOrigin-RevId: 761425663
commit 1ce89788ef
Author: Copybara-Service
Date:   2025-05-21 01:58:00 -07:00

2 changed files with 11 additions and 10 deletions

@@ -1556,7 +1556,6 @@ void GenerateImageTokensT(const ModelStore& model,
   prefill_runtime_config.prefill_tbatch_size =
       vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
   Activations prefill_activations(vit_config, vit_config.seq_len, env);
-  prefill_activations.SetBatchSize(prefill_runtime_config.prefill_tbatch_size);
   // Weights are for the full PaliGemma model, not just the ViT part.
   PrefillVit(weights, prefill_runtime_config, image, image_tokens,
              prefill_activations);
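
In the hunk above, the prefill token-batch size is computed directly from the ViT configuration (sequence length divided by the squared pooling factor), and the separate SetBatchSize call on the activations is dropped. A minimal sketch of that arithmetic, using illustrative values in place of the real VitConfig fields:

    #include <cstddef>
    #include <cstdio>

    int main() {
      // Illustrative values only; the real ones come from the model's ViT config.
      const std::size_t seq_len = 4096;   // number of ViT patch embeddings
      const std::size_t pool_dim = 4;     // average-pooling factor per side
      // Same expression as in the diff: seq_len / (pool_dim * pool_dim).
      const std::size_t prefill_tbatch_size = seq_len / (pool_dim * pool_dim);
      std::printf("prefill_tbatch_size = %zu\n", prefill_tbatch_size);  // prints 256
      return 0;
    }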

@@ -136,7 +136,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     ++tokens_generated_this_turn;
     if (in_prompt) {
       if (inference.verbosity >= 1) {
-        std::cerr << "." << std::flush;
+        std::cout << "." << std::flush;
       }
       return true;
     } else if (config.IsEOS(token)) {
@@ -188,7 +188,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                               .use_spinning = threading.spin};
   inference.CopyTo(runtime_config);
   std::vector<int> prompt;
-  size_t prompt_size = 0;
   size_t prefix_end = 0;
   if (have_image) {
     prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
@@ -196,12 +195,15 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                              image_tokens.Rows());
     runtime_config.image_tokens = &image_tokens;
     prompt_size = prompt.size();
-    // The end of the prefix for prefix-LM style attention in Paligemma.
-    // See Figure 2 of https://arxiv.org/abs/2407.07726.
-    prefix_end = prompt_size;
-    // REMOVED: Don't change prefill_tbatch_size for image handling
-    // runtime_config.prefill_tbatch_size = prompt_size;
+    if (config.wrapping == PromptWrapping::PALIGEMMA) {
+      // The end of the prefix for prefix-LM style attention in Paligemma.
+      // See Figure 2 of https://arxiv.org/abs/2407.07726.
+      prefix_end = prompt_size;
+      // We need to look at all the tokens for the prefix.
+      // NOTE: Online softmax is on the roadmap, after which this requirement
+      // can be lifted.
+      runtime_config.prefill_tbatch_size = prompt_size;
+    }
   } else {
     prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
                              config.wrapping, abs_pos, prompt_string);
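
Taken together, the last two hunks re-enable the prefill_tbatch_size override, now limited to PaliGemma-style prompts: for prefix-LM attention the prefix end and the prefill token batch both cover the full prompt (image tokens plus text), while other prompt wrappings keep the default prefill batching. A standalone sketch of that branch; the enum, struct, and helper below are illustrative stand-ins, not the gemma.cpp API:

    #include <cstddef>

    // Hypothetical stand-ins for the real gemma.cpp types.
    enum class PromptWrapping { GEMMA_IT, PALIGEMMA };
    struct RuntimeConfig {
      std::size_t prefill_tbatch_size = 64;  // some default token batch size
    };

    // Only PaliGemma uses prefix-LM attention (Figure 2 of
    // https://arxiv.org/abs/2407.07726), so only then must the whole prefix
    // be prefilled as a single token batch.
    void ConfigureImagePrefill(PromptWrapping wrapping, std::size_t prompt_size,
                               std::size_t& prefix_end, RuntimeConfig& rc) {
      if (wrapping == PromptWrapping::PALIGEMMA) {
        prefix_end = prompt_size;              // attention prefix spans the whole prompt
        rc.prefill_tbatch_size = prompt_size;  // prefill it in one batch
      }
      // Other wrappings keep causal attention and the default batch size.
    }

    int main() {
      RuntimeConfig rc;
      std::size_t prefix_end = 0;
      ConfigureImagePrefill(PromptWrapping::PALIGEMMA, /*prompt_size=*/260,
                            prefix_end, rc);
      // prefix_end == 260 and rc.prefill_tbatch_size == 260 for PaliGemma;
      // both stay untouched for PromptWrapping::GEMMA_IT.
      return 0;
    }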