mirror of https://github.com/google/gemma.cpp.git

Minor fixes for ViT

parent cb188d4a0e
commit 6debdbe341
@@ -1556,7 +1556,6 @@ void GenerateImageTokensT(const ModelStore& model,
   prefill_runtime_config.prefill_tbatch_size =
       vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
   Activations prefill_activations(vit_config, vit_config.seq_len, env);
-  prefill_activations.SetBatchSize(prefill_runtime_config.prefill_tbatch_size);
   // Weights are for the full PaliGemma model, not just the ViT part.
   PrefillVit(weights, prefill_runtime_config, image, image_tokens,
              prefill_activations);
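For context, the token batch size computed in this hunk is the number of ViT embeddings left after average pooling: the encoder emits vit_config.seq_len patch tokens, and pooling over pool_dim x pool_dim windows divides that count. The removed SetBatchSize call would have re-sized activations that the constructor on the preceding line already sizes for the full seq_len, which is presumably why it was dropped. A minimal sketch of the arithmetic, using illustrative values (an 896x896 input, 14x14 patches, 4x4 pooling; these numbers are assumptions, not taken from this diff):

// Sketch of the batch-size arithmetic in the hunk above. All values are
// illustrative assumptions (a Gemma-3-style vision encoder), not from the diff.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t image_size = 896;  // assumed input resolution
  const size_t patch_size = 14;   // assumed ViT patch width
  const size_t pool_dim = 4;      // assumed average-pooling window

  // 896 / 14 = 64 patches per side -> 64 * 64 = 4096 patch tokens.
  const size_t seq_len = (image_size / patch_size) * (image_size / patch_size);
  // Pooling 4x4 windows leaves 4096 / 16 = 256 tokens, which is the prefill
  // token-batch size assigned in the hunk above.
  const size_t tbatch_size = seq_len / (pool_dim * pool_dim);
  std::printf("seq_len=%zu tbatch_size=%zu\n", seq_len, tbatch_size);
  return 0;
}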
gemma/run.cc (20 changed lines)
@@ -136,7 +136,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
     ++tokens_generated_this_turn;
     if (in_prompt) {
       if (inference.verbosity >= 1) {
-        std::cerr << "." << std::flush;
+        std::cout << "." << std::flush;
       }
       return true;
     } else if (config.IsEOS(token)) {
@@ -188,7 +188,6 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                                  .use_spinning = threading.spin};
   inference.CopyTo(runtime_config);
   std::vector<int> prompt;
-  size_t prompt_size = 0;
   size_t prefix_end = 0;
   if (have_image) {
     prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
@@ -196,12 +195,15 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
                              image_tokens.Rows());
     runtime_config.image_tokens = &image_tokens;
     prompt_size = prompt.size();
-    // The end of the prefix for prefix-LM style attention in Paligemma.
-    // See Figure 2 of https://arxiv.org/abs/2407.07726.
-    prefix_end = prompt_size;
-    // REMOVED: Don't change prefill_tbatch_size for image handling
-    // runtime_config.prefill_tbatch_size = prompt_size;
+    if (config.wrapping == PromptWrapping::PALIGEMMA) {
+      // The end of the prefix for prefix-LM style attention in Paligemma.
+      // See Figure 2 of https://arxiv.org/abs/2407.07726.
+      prefix_end = prompt_size;
+      // We need to look at all the tokens for the prefix.
+      // NOTE: Online softmax is on the roadmap, after which this requirement
+      // can be lifted.
+      runtime_config.prefill_tbatch_size = prompt_size;
+    }
   } else {
     prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
                              config.wrapping, abs_pos, prompt_string);
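The re-enabled prefill_tbatch_size = prompt_size assignment, together with the new comments, ties the batch size to PaliGemma's prefix-LM attention: every token of the image-plus-prompt prefix must see every other prefix token during prefill, so the whole prefix has to be processed as a single token batch until online softmax (per the NOTE above) lifts that requirement. A minimal sketch of the masking rule, assuming a hypothetical Attends() helper rather than gemma.cpp's actual attention code:

// Hedged sketch of prefix-LM masking (see Figure 2 of
// https://arxiv.org/abs/2407.07726). Attends() is hypothetical, not a
// gemma.cpp API.
#include <cstddef>

// Returns whether the query at position `from` may attend to position `to`,
// given that positions [0, prefix_end) form the bidirectional prefix.
bool Attends(size_t from, size_t to, size_t prefix_end) {
  if (from < prefix_end) return to < prefix_end;  // prefix: bidirectional
  return to <= from;                              // suffix: causal
}

With prefix_end = prompt_size as set above, the image tokens and the wrapped text prompt form the bidirectional prefix, while generated tokens remain causal.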
@@ -308,4 +310,4 @@ int main(int argc, char** argv) {
   }
   PROFILER_PRINT_RESULTS();  // Must call outside the zone above.
   return 0;
 }