From ca4ee2b63f9f92eb572c088c600d00e89b776f3a Mon Sep 17 00:00:00 2001
From: RangerUFO
Date: Wed, 26 Mar 2025 18:19:05 +0800
Subject: [PATCH 1/4] Refactor `WrapAndTokenize` to work properly with Gemma3

---
 evals/benchmark_helper.h | 2 +-
 evals/gemma_test.cc | 10 ++-
 examples/hello_world/run.cc | 3 +-
 examples/simplified_gemma/gemma.hpp | 5 +-
 gemma/common.cc | 12 ---
 gemma/common.h | 3 -
 gemma/gemma.cc | 5 +-
 gemma/gemma.h | 2 +
 gemma/run.cc | 33 +++----
 gemma/tokenizer.cc | 131 ++++++++++++++++++----------
 gemma/tokenizer.h | 29 ++++--
 11 files changed, 139 insertions(+), 96 deletions(-)

diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h
index f6e32c0..06523a1 100644
--- a/evals/benchmark_helper.h
+++ b/evals/benchmark_helper.h
@@ -69,7 +69,7 @@ class GemmaEnv {
   }
 
   std::vector<int> WrapAndTokenize(std::string& input) const {
-    return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->Info(), 0, input);
+    return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->ChatTemplate(), model_->Info(), 0, input);
   }
 
   std::string StringFromTokens(const std::vector<int>& tokens) const {
diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc
index 7674c5e..2d3547f 100644
--- a/evals/gemma_test.cc
+++ b/evals/gemma_test.cc
@@ -178,16 +178,18 @@ TEST_F(GemmaTest, Multiturn) {
   TimingInfo timing_info{.verbosity = 0};
   // First "say" something slightly unusual.
   std::string mutable_prompt = "I have a car and its color is turquoise.";
-  std::vector<int> tokens = WrapAndTokenize(model->Tokenizer(), model->Info(),
-                                            abs_pos, mutable_prompt);
+  std::vector<int> tokens = WrapAndTokenize(model->Tokenizer(),
+                                            model->ChatTemplate(),
+                                            model->Info(), abs_pos,
+                                            mutable_prompt);
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                   timing_info);
   // Note: we do not rewind any <end_of_turn> tokens here. If the model
   // produced one and WrapAndTokenize() inserts another one, it will just be
   // duplicated.
   mutable_prompt = "Please repeat all prior statements.";
-  tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
-                           mutable_prompt);
+  tokens = WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
+                           model->Info(), abs_pos, mutable_prompt);
   // Reset the `response` string here, then check that the model actually has
   // access to the previous turn by asking to reproduce.
   response.clear();
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index fb2fea3..96724be 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -75,7 +75,8 @@ int main(int argc, char** argv) {
   // Tokenize instructions.
std::string prompt = "Write a greeting to the world."; const std::vector tokens = gcpp::WrapAndTokenize( - model.Tokenizer(), loader.Info(), generated, prompt); + model.Tokenizer(), model.ChatTemplate(), loader.Info(), + generated, prompt); const size_t prompt_size = tokens.size(); // This callback function gets invoked every time a token is generated diff --git a/examples/simplified_gemma/gemma.hpp b/examples/simplified_gemma/gemma.hpp index a2a7760..5047866 100644 --- a/examples/simplified_gemma/gemma.hpp +++ b/examples/simplified_gemma/gemma.hpp @@ -72,7 +72,8 @@ class SimplifiedGemma { size_t generated = 0; const std::vector tokens = gcpp::WrapAndTokenize( - model_.Tokenizer(), loader_.Info(), generated, prompt); + model_.Tokenizer(), model_.ChatTemplate(), loader_.Info(), + generated, prompt); const size_t prompt_size = tokens.size(); // This callback function gets invoked every time a token is generated @@ -115,4 +116,4 @@ class SimplifiedGemma { gcpp::KVCache kv_cache_; std::mt19937 gen_; std::string validation_error_; -}; \ No newline at end of file +}; diff --git a/gemma/common.cc b/gemma/common.cc index 0d8977b..dec9781 100644 --- a/gemma/common.cc +++ b/gemma/common.cc @@ -148,18 +148,6 @@ const char* ParseType(const std::string& type_string, Type& type) { return kErrorMessageBuffer.c_str(); } -void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) { - - // Instruction-tuned models are trained to expect control tokens. - if (info.wrapping == PromptWrapping::GEMMA_IT) { - // Prepend "" if this is a multi-turn dialogue continuation. - const std::string start = (pos == 0) - ? "user\n" - : "\nuser\n"; - prompt = start + prompt + "\nmodel\n"; - } -} - float EmbeddingScaling(size_t model_dim) { // Round to bf16 to match Gemma's Embedder, which casts before mul. return hwy::ConvertScalarTo(hwy::ConvertScalarTo( diff --git a/gemma/common.h b/gemma/common.h index 984b0ba..bf4fc7e 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -43,9 +43,6 @@ const char* ParseType(const std::string& type_string, Type& type); const char* ModelString(Model model, PromptWrapping wrapping); const char* StringFromType(Type type); -// Wraps the given prompt using the expected control tokens for IT models. -void Wrap(const ModelInfo& info, size_t pos, std::string& prompt); - // Returns the scale value to use for the embedding (basically sqrt model_dim). 
 float EmbeddingScaling(size_t model_dim);
 
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index bfc6534..1992d9c 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -40,7 +40,7 @@ namespace gcpp {
 
 Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
              const ModelInfo& info, MatMulEnv& env)
-    : env_(env), tokenizer_(tokenizer_path) {
+    : env_(env), tokenizer_(tokenizer_path), chat_template_(tokenizer_) {
   model_.Load(weights, info.model, info.weight, info.wrapping,
               env_.parallel.Pools().Pool(0),
               /*tokenizer_proto=*/nullptr);
@@ -51,10 +51,11 @@ Gemma::Gemma(const Path& weights, MatMulEnv& env) : env_(env) {
   model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT,
               env_.parallel.Pools().Pool(0), &tokenizer_proto);
   tokenizer_.Deserialize(tokenizer_proto);
+  chat_template_.Init(tokenizer_);
 }
 
 Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env)
-    : env_(env), tokenizer_(std::move(tokenizer)) {
+    : env_(env), tokenizer_(std::move(tokenizer)), chat_template_(tokenizer_) {
   HWY_ASSERT(info.weight == Type::kF32);
   model_.Allocate(info.model, info.weight, env_.parallel.Pools().Pool(0));
 }
diff --git a/gemma/gemma.h b/gemma/gemma.h
index ccda69c..de0cba1 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -213,6 +213,7 @@ class Gemma {
                      .weight = model_.Config().weight});
   }
   const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
+  const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
   const ModelWeightsStorage& Weights() const { return model_; }
   ModelWeightsStorage& MutableWeights() { return model_; }
   void Save(const Path& weights, hwy::ThreadPool& pool) {
@@ -256,6 +257,7 @@ class Gemma {
   MatMulEnv& env_;
 
   GemmaTokenizer tokenizer_;
+  GemmaChatTemplate chat_template_;
   // Type-erased so that this can be defined in the header.
   ModelWeightsStorage model_;
 };
diff --git a/gemma/run.cc b/gemma/run.cc
index 254d13f..0d69011 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -162,16 +162,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
       continue;
     }
 
-    // Wrap, tokenize and maybe log prompt tokens.
-    std::vector<int> prompt = WrapAndTokenize(
-        model.Tokenizer(), model.Info(), abs_pos, prompt_string);
-    prompt_size = prompt.size();
-    if constexpr (kVerboseLogTokens) {
-      for (int i = 0; i < prompt_size; ++i) {
-        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
-      }
-    }
-
+    std::vector<int> prompt;
     // Set up runtime config.
     TimingInfo timing_info = {.verbosity = app.verbosity};
     RuntimeConfig runtime_config = {.gen = &gen,
@@ -182,22 +173,26 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
     args.CopyTo(runtime_config);
     size_t prefix_end = 0;
    if (have_image) {
+      prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
+                               model.Info(), abs_pos, prompt_string,
+                               image_tokens.BatchSize());
       runtime_config.image_tokens = &image_tokens;
-      if (model.Info().wrapping == PromptWrapping::PALIGEMMA) {
-        prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
-      } else if (model.Info().wrapping == PromptWrapping::GEMMA_VLM) {
-        size_t seq_len = model.GetModelConfig().vit_config.seq_len;
-        size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
-        prompt =
-            WrapVLM(model.Tokenizer(), model.Info(), abs_pos, prompt,
-                    image_tokens.BatchSize(), seq_len / (pool_dim * pool_dim));
-      }
       prompt_size = prompt.size();
       // The end of the prefix for prefix-LM style attention in Paligemma.
       // See Figure 2 of https://arxiv.org/abs/2407.07726.
       prefix_end = prompt_size;
       // We need to look at all the tokens for the prefix.
       runtime_config.prefill_tbatch_size = prompt_size;
+    } else {
+      prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
+                               model.Info(), abs_pos, prompt_string);
+      prompt_size = prompt.size();
+    }
+
+    if constexpr (kVerboseLogTokens) {
+      for (int i = 0; i < prompt_size; ++i) {
+        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
+      }
     }
 
     // Generate until EOS or max_generated_tokens.
diff --git a/gemma/tokenizer.cc b/gemma/tokenizer.cc
index e48abae..159be26 100644
--- a/gemma/tokenizer.cc
+++ b/gemma/tokenizer.cc
@@ -114,57 +114,96 @@ bool GemmaTokenizer::Decode(const std::vector<int>& ids,
   return impl_->Decode(ids, detokenized);
 }
 
-std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
-                                 const ModelInfo& info, size_t pos,
-                                 std::string& prompt) {
-  Wrap(info, pos, prompt);
-
-  std::vector<int> tokens;
-  HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
-  // Both pre-trained and instruction-tuned require BOS as first token.
-  if (pos == 0) {
-    tokens.insert(tokens.begin(), BOS_ID);
-  }
-
-  // PaliGemma separator. The SEP token "\n" is always tokenized separately.
-  if (info.wrapping == PromptWrapping::PALIGEMMA
-      // || info.wrapping == PromptWrapping::GEMMA_VLM
-  ) {
-    std::vector<int> sep_tokens;
-    HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
-    tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
-  }
-
-  return tokens;
+GemmaChatTemplate::GemmaChatTemplate(const GemmaTokenizer& tokenizer) {
+  Init(tokenizer);
 }
 
-std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
-                         size_t pos, std::vector<int>& tokens,
-                         size_t image_batch_size, size_t max_image_batch_size) {
-  HWY_ASSERT(info.wrapping == PromptWrapping::GEMMA_VLM);
-  size_t num_images = hwy::DivCeil(image_batch_size, max_image_batch_size);
+void GemmaChatTemplate::Init(const GemmaTokenizer& tokenizer) {
+  sot_user_.reserve(3);
+  HWY_ASSERT(tokenizer.Encode("<start_of_turn>user\n", &sot_user_));
+  sot_model_.reserve(3);
+  HWY_ASSERT(tokenizer.Encode("<start_of_turn>model\n", &sot_model_));
+  eot_.reserve(2);
+  HWY_ASSERT(tokenizer.Encode("<end_of_turn>\n", &eot_));
+}
 
-  std::vector<int> sep_tokens;
-  HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
-
-  std::string begin_image_prompt = "\n\n<start_of_image>";
-  std::vector<int> begin_image_tokens =
-      WrapAndTokenize(tokenizer, info, pos, begin_image_prompt);
-
-  std::string end_image_prompt = "<end_of_image>\n\n";
-  std::vector<int> end_image_tokens =
-      WrapAndTokenize(tokenizer, info, pos, end_image_prompt);
-
-  for (size_t i = 0; i < num_images; ++i) {
-    tokens.insert(tokens.begin(), begin_image_tokens.begin(),
-                  begin_image_tokens.end());
-    tokens.insert(tokens.begin() + begin_image_tokens.size(), image_batch_size,
-                  -2);
-    tokens.insert(tokens.begin() + begin_image_tokens.size() + image_batch_size,
-                  end_image_tokens.begin(), end_image_tokens.end());
+std::vector<int> GemmaChatTemplate::Apply(size_t pos,
+                                          const std::vector<int>& ids) const {
+  HWY_ASSERT_M(!sot_user_.empty() && !sot_model_.empty() && !eot_.empty(),
+               "GemmaChatTemplate has not been initialized.");
+  std::vector<int> out;
+  out.reserve(eot_.size() +
+              sot_user_.size() +
+              ids.size() +
+              eot_.size() +
+              sot_model_.size());
+  if (pos > 0) {
+    out.insert(out.cend(), eot_.cbegin(), eot_.cend());
+  } else {
+    out.push_back(BOS_ID);
   }
+  out.insert(out.cend(), sot_user_.cbegin(), sot_user_.cend());
+  out.insert(out.cend(), ids.cbegin(), ids.cend());
+  out.insert(out.cend(), eot_.cbegin(), eot_.cend());
+  out.insert(out.cend(), sot_model_.cbegin(), sot_model_.cend());
+  return out;
+}
 
-  return tokens;
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt) {
+  std::vector<int> tokens;
+  HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
+  switch (info.wrapping) {
+    case PromptWrapping::GEMMA_IT:
+    case PromptWrapping::GEMMA_VLM:
+      return chat_template.Apply(pos, tokens);
+    default:
+      if (pos == 0) {
+        tokens.insert(tokens.cbegin(), BOS_ID);
+      }
+      return tokens;
+  }
+}
+
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt,
+                                 size_t image_batch_size) {
+  std::vector<int> text_part;
+  HWY_ASSERT(tokenizer.Encode(prompt, &text_part));
+  std::vector<int> tokens;
+  switch (info.wrapping) {
+    case PromptWrapping::PALIGEMMA: {
+      std::vector<int> sep;
+      HWY_ASSERT(tokenizer.Encode("\n", &sep));
+      tokens.reserve(image_batch_size + 1 + text_part.size() + sep.size());
+      tokens.resize(image_batch_size, 0);
+      HWY_ASSERT(pos == 0);
+      tokens.push_back(BOS_ID);
+      tokens.insert(tokens.cend(), text_part.cbegin(), text_part.cend());
+      tokens.insert(tokens.cend(), sep.cbegin(), sep.cend());
+      return tokens;
+    }
+    case PromptWrapping::GEMMA_VLM: {
+      std::vector<int> soi;
+      soi.reserve(2);
+      HWY_ASSERT(tokenizer.Encode("\n\n<start_of_image>", &soi));
+      std::vector<int> eoi;
+      eoi.reserve(2);
+      HWY_ASSERT(tokenizer.Encode("<end_of_image>\n\n", &eoi));
+      tokens.reserve(text_part.size() + soi.size() + image_batch_size + eoi.size());
+      tokens.insert(tokens.cend(), text_part.cbegin(), text_part.cend());
+      tokens.insert(tokens.cend(), soi.cbegin(), soi.cend());
+      tokens.insert(tokens.cend(), image_batch_size, -2);
+      tokens.insert(tokens.cend(), eoi.cbegin(), eoi.cend());
+      return chat_template.Apply(pos, tokens);
+    }
+    default:
+      HWY_ASSERT_M(false, "Current variant does not support vision prompt.");
+  }
 }
 
 }  // namespace gcpp
diff --git a/gemma/tokenizer.h b/gemma/tokenizer.h
index 0bbd8f4..a5d329d 100644
--- a/gemma/tokenizer.h
+++ b/gemma/tokenizer.h
@@ -54,13 +54,30 @@ class GemmaTokenizer {
   std::unique_ptr<Impl> impl_;
 };
 
-std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
-                                 const ModelInfo& info, size_t pos,
-                                 std::string& prompt);
+class GemmaChatTemplate {
+ public:
+  GemmaChatTemplate() = default;
+  explicit GemmaChatTemplate(const GemmaTokenizer& tokenizer);
 
-std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
-                         size_t pos, std::vector<int>& tokens,
-                         size_t image_batch_size, size_t max_image_batch_size);
+  void Init(const GemmaTokenizer& tokenizer);
+  std::vector<int> Apply(size_t pos, const std::vector<int>& ids) const;
+
+ private:
+  std::vector<int> sot_user_;
+  std::vector<int> sot_model_;
+  std::vector<int> eot_;
+};
+
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt);
+
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt,
+                                 size_t image_batch_size);
 
 }  // namespace gcpp

From d1615b56b2549835a27929f04ca6e48c22267a13 Mon Sep 17 00:00:00 2001
From: RangerUFO
Date: Wed, 26 Mar 2025 18:27:09 +0800
Subject: [PATCH 2/4] Fix the prompt wrapping of gemma3-1b again

It seems that the previous fix was changed back due to a merge error.
---
 gemma/common.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gemma/common.cc b/gemma/common.cc
index dec9781..2128159 100644
--- a/gemma/common.cc
+++ b/gemma/common.cc
@@ -80,7 +80,7 @@ constexpr PromptWrapping kPromptWrapping[] = {
     PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA,  // PG2 3B 224/448
     PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA,  // PG2 10B 224/448
     PromptWrapping::GEMMA_VLM,                             // Gemma3 4B
-    PromptWrapping::GEMMA_PT,                              // Gemma3 1B
+    PromptWrapping::GEMMA_IT,                              // Gemma3 1B
    PromptWrapping::GEMMA_VLM,                             // Gemma3 12B
     PromptWrapping::GEMMA_VLM,                             // Gemma3 27B
 };

From c39295f497b8fecf09ac7976d10cbe507a29bf12 Mon Sep 17 00:00:00 2001
From: RangerUFO
Date: Thu, 27 Mar 2025 14:01:56 +0800
Subject: [PATCH 3/4] Inline the ctor of `GemmaChatTemplate`

---
 gemma/tokenizer.cc | 4 ----
 gemma/tokenizer.h | 4 +++-
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/gemma/tokenizer.cc b/gemma/tokenizer.cc
index 159be26..275e836 100644
--- a/gemma/tokenizer.cc
+++ b/gemma/tokenizer.cc
@@ -114,10 +114,6 @@ bool GemmaTokenizer::Decode(const std::vector<int>& ids,
   return impl_->Decode(ids, detokenized);
 }
 
-GemmaChatTemplate::GemmaChatTemplate(const GemmaTokenizer& tokenizer) {
-  Init(tokenizer);
-}
-
 void GemmaChatTemplate::Init(const GemmaTokenizer& tokenizer) {
   sot_user_.reserve(3);
   HWY_ASSERT(tokenizer.Encode("<start_of_turn>user\n", &sot_user_));
diff --git a/gemma/tokenizer.h b/gemma/tokenizer.h
index a5d329d..6cf5552 100644
--- a/gemma/tokenizer.h
+++ b/gemma/tokenizer.h
@@ -57,7 +57,9 @@ class GemmaTokenizer {
 class GemmaChatTemplate {
  public:
   GemmaChatTemplate() = default;
-  explicit GemmaChatTemplate(const GemmaTokenizer& tokenizer);
+  explicit GemmaChatTemplate(const GemmaTokenizer& tokenizer) {
+    Init(tokenizer);
+  }
 
   void Init(const GemmaTokenizer& tokenizer);
   std::vector<int> Apply(size_t pos, const std::vector<int>& ids) const;

From cc2e14e65401190e301e10627cb6afcc18fe457d Mon Sep 17 00:00:00 2001
From: RangerUFO
Date: Thu, 27 Mar 2025 15:57:53 +0800
Subject: [PATCH 4/4] Improve `GemmaChatTemplate` to handle vision prompt wrapping

---
 gemma/tokenizer.cc | 62 +++++++++++++++++++++++++++-------------------
 gemma/tokenizer.h | 7 ++++++
 2 files changed, 44 insertions(+), 25 deletions(-)

diff --git a/gemma/tokenizer.cc b/gemma/tokenizer.cc
index 275e836..39ade02 100644
--- a/gemma/tokenizer.cc
+++ b/gemma/tokenizer.cc
@@ -121,6 +121,11 @@ void GemmaChatTemplate::Init(const GemmaTokenizer& tokenizer) {
   HWY_ASSERT(tokenizer.Encode("<start_of_turn>model\n", &sot_model_));
   eot_.reserve(2);
   HWY_ASSERT(tokenizer.Encode("<end_of_turn>\n", &eot_));
+  HWY_ASSERT(tokenizer.Encode("\n", &pali_sep_));
+  vlm_soi_.reserve(2);
+  HWY_ASSERT(tokenizer.Encode("\n\n<start_of_image>", &vlm_soi_));
+  vlm_eoi_.reserve(2);
+  HWY_ASSERT(tokenizer.Encode("<end_of_image>\n\n", &vlm_eoi_));
 }
 
 std::vector<int> GemmaChatTemplate::Apply(size_t pos,
@@ -145,6 +150,33 @@ std::vector<int> GemmaChatTemplate::Apply(size_t pos,
   return out;
 }
 
+std::vector<int> GemmaChatTemplate::WrapPali(const std::vector<int>& text_part,
+                                             size_t image_batch_size) const {
+  HWY_ASSERT_M(!pali_sep_.empty(),
+               "GemmaChatTemplate has not been initialized.");
+  std::vector<int> out;
+  out.reserve(image_batch_size + 1 + text_part.size() + pali_sep_.size());
+  out.resize(image_batch_size, 0);
+  out.push_back(BOS_ID);
+  out.insert(out.cend(), text_part.cbegin(), text_part.cend());
+  out.insert(out.cend(), pali_sep_.cbegin(), pali_sep_.cend());
+  return out;
+}
+
+std::vector<int> GemmaChatTemplate::WrapVLM(const std::vector<int>& text_part,
+                                            size_t image_batch_size) const {
+  HWY_ASSERT_M(!vlm_soi_.empty() && !vlm_eoi_.empty(),
+               "GemmaChatTemplate has not been initialized.");
+  std::vector<int> out;
+  out.reserve(
+      text_part.size() + vlm_soi_.size() + image_batch_size + vlm_eoi_.size());
+  out.insert(out.cend(), text_part.cbegin(), text_part.cend());
+  out.insert(out.cend(), vlm_soi_.cbegin(), vlm_soi_.cend());
+  out.insert(out.cend(), image_batch_size, -2);
+  out.insert(out.cend(), vlm_eoi_.cbegin(), vlm_eoi_.cend());
+  return out;
+}
+
 std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
                                  const GemmaChatTemplate& chat_template,
                                  const ModelInfo& info, size_t pos,
@@ -170,33 +202,13 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
                                  size_t image_batch_size) {
   std::vector<int> text_part;
   HWY_ASSERT(tokenizer.Encode(prompt, &text_part));
-  std::vector<int> tokens;
   switch (info.wrapping) {
-    case PromptWrapping::PALIGEMMA: {
-      std::vector<int> sep;
-      HWY_ASSERT(tokenizer.Encode("\n", &sep));
-      tokens.reserve(image_batch_size + 1 + text_part.size() + sep.size());
-      tokens.resize(image_batch_size, 0);
+    case PromptWrapping::PALIGEMMA:
       HWY_ASSERT(pos == 0);
-      tokens.push_back(BOS_ID);
-      tokens.insert(tokens.cend(), text_part.cbegin(), text_part.cend());
-      tokens.insert(tokens.cend(), sep.cbegin(), sep.cend());
-      return tokens;
-    }
-    case PromptWrapping::GEMMA_VLM: {
-      std::vector<int> soi;
-      soi.reserve(2);
-      HWY_ASSERT(tokenizer.Encode("\n\n<start_of_image>", &soi));
-      std::vector<int> eoi;
-      eoi.reserve(2);
-      HWY_ASSERT(tokenizer.Encode("<end_of_image>\n\n", &eoi));
-      tokens.reserve(text_part.size() + soi.size() + image_batch_size + eoi.size());
-      tokens.insert(tokens.cend(), text_part.cbegin(), text_part.cend());
-      tokens.insert(tokens.cend(), soi.cbegin(), soi.cend());
-      tokens.insert(tokens.cend(), image_batch_size, -2);
-      tokens.insert(tokens.cend(), eoi.cbegin(), eoi.cend());
-      return chat_template.Apply(pos, tokens);
-    }
+      return chat_template.WrapPali(text_part, image_batch_size);
+    case PromptWrapping::GEMMA_VLM:
+      return chat_template.Apply(pos, chat_template.WrapVLM(text_part,
+                                                            image_batch_size));
     default:
       HWY_ASSERT_M(false, "Current variant does not support vision prompt.");
   }
diff --git a/gemma/tokenizer.h b/gemma/tokenizer.h
index 6cf5552..b4c511f 100644
--- a/gemma/tokenizer.h
+++ b/gemma/tokenizer.h
@@ -63,11 +63,18 @@ class GemmaChatTemplate {
 
   void Init(const GemmaTokenizer& tokenizer);
   std::vector<int> Apply(size_t pos, const std::vector<int>& ids) const;
+  std::vector<int> WrapPali(const std::vector<int>& text_part,
+                            size_t image_batch_size) const;
+  std::vector<int> WrapVLM(const std::vector<int>& text_part,
+                           size_t image_batch_size) const;
 
  private:
   std::vector<int> sot_user_;
   std::vector<int> sot_model_;
   std::vector<int> eot_;
+  std::vector<int> pali_sep_;
+  std::vector<int> vlm_soi_;
+  std::vector<int> vlm_eoi_;
 };
 
 std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,