From cc2e14e65401190e301e10627cb6afcc18fe457d Mon Sep 17 00:00:00 2001
From: RangerUFO
Date: Thu, 27 Mar 2025 15:57:53 +0800
Subject: [PATCH] Improve `GemmaChatTemplate` to handle vision prompt wrapping

---
 gemma/tokenizer.cc | 62 +++++++++++++++++++++++++++-------------------
 gemma/tokenizer.h  |  7 ++++++
 2 files changed, 44 insertions(+), 25 deletions(-)

diff --git a/gemma/tokenizer.cc b/gemma/tokenizer.cc
index 275e836..39ade02 100644
--- a/gemma/tokenizer.cc
+++ b/gemma/tokenizer.cc
@@ -121,6 +121,11 @@ void GemmaChatTemplate::Init(const GemmaTokenizer& tokenizer) {
   HWY_ASSERT(tokenizer.Encode("<start_of_turn>model\n", &sot_model_));
   eot_.reserve(2);
   HWY_ASSERT(tokenizer.Encode("<end_of_turn>\n", &eot_));
+  HWY_ASSERT(tokenizer.Encode("\n", &pali_sep_));
+  vlm_soi_.reserve(2);
+  HWY_ASSERT(tokenizer.Encode("\n\n<start_of_image>", &vlm_soi_));
+  vlm_eoi_.reserve(2);
+  HWY_ASSERT(tokenizer.Encode("<end_of_image>\n\n", &vlm_eoi_));
 }
 
 std::vector<int> GemmaChatTemplate::Apply(size_t pos,
@@ -145,6 +150,33 @@ std::vector<int> GemmaChatTemplate::Apply(size_t pos,
   return out;
 }
 
+std::vector<int> GemmaChatTemplate::WrapPali(const std::vector<int>& text_part,
+                                             size_t image_batch_size) const {
+  HWY_ASSERT_M(!pali_sep_.empty(),
+               "GemmaChatTemplate has not been initialized.");
+  std::vector<int> out;
+  out.reserve(image_batch_size + 1 + text_part.size() + pali_sep_.size());
+  out.resize(image_batch_size, 0);
+  out.push_back(BOS_ID);
+  out.insert(out.cend(), text_part.cbegin(), text_part.cend());
+  out.insert(out.cend(), pali_sep_.cbegin(), pali_sep_.cend());
+  return out;
+}
+
+std::vector<int> GemmaChatTemplate::WrapVLM(const std::vector<int>& text_part,
+                                            size_t image_batch_size) const {
+  HWY_ASSERT_M(!vlm_soi_.empty() && !vlm_eoi_.empty(),
+               "GemmaChatTemplate has not been initialized.");
+  std::vector<int> out;
+  out.reserve(
+      text_part.size() + vlm_soi_.size() + image_batch_size + vlm_eoi_.size());
+  out.insert(out.cend(), text_part.cbegin(), text_part.cend());
+  out.insert(out.cend(), vlm_soi_.cbegin(), vlm_soi_.cend());
+  out.insert(out.cend(), image_batch_size, -2);
+  out.insert(out.cend(), vlm_eoi_.cbegin(), vlm_eoi_.cend());
+  return out;
+}
+
 std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
                                  const GemmaChatTemplate& chat_template,
                                  const ModelInfo& info, size_t pos,
@@ -170,33 +202,13 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
                                  size_t image_batch_size) {
   std::vector<int> text_part;
   HWY_ASSERT(tokenizer.Encode(prompt, &text_part));
-  std::vector<int> tokens;
   switch (info.wrapping) {
-    case PromptWrapping::PALIGEMMA: {
-      std::vector<int> sep;
-      HWY_ASSERT(tokenizer.Encode("\n", &sep));
-      tokens.reserve(image_batch_size + 1 + text_part.size() + sep.size());
-      tokens.resize(image_batch_size, 0);
+    case PromptWrapping::PALIGEMMA:
       HWY_ASSERT(pos == 0);
-      tokens.push_back(BOS_ID);
-      tokens.insert(tokens.cend(), text_part.cbegin(), text_part.cend());
-      tokens.insert(tokens.cend(), sep.cbegin(), sep.cend());
-      return tokens;
-    }
-    case PromptWrapping::GEMMA_VLM: {
-      std::vector<int> soi;
-      soi.reserve(2);
-      HWY_ASSERT(tokenizer.Encode("\n\n<start_of_image>", &soi));
-      std::vector<int> eoi;
-      eoi.reserve(2);
-      HWY_ASSERT(tokenizer.Encode("<end_of_image>\n\n", &eoi));
-      tokens.reserve(text_part.size() + soi.size() + image_batch_size + eoi.size());
-      tokens.insert(tokens.cend(), text_part.cbegin(), text_part.cend());
-      tokens.insert(tokens.cend(), soi.cbegin(), soi.cend());
-      tokens.insert(tokens.cend(), image_batch_size, -2);
-      tokens.insert(tokens.cend(), eoi.cbegin(), eoi.cend());
-      return chat_template.Apply(pos, tokens);
-    }
+      return chat_template.WrapPali(text_part, image_batch_size);
+    case PromptWrapping::GEMMA_VLM:
+      return chat_template.Apply(pos, chat_template.WrapVLM(text_part,
+                                                            image_batch_size));
     default:
       HWY_ASSERT_M(false, "Current variant does not support vision prompt.");
   }
diff --git a/gemma/tokenizer.h b/gemma/tokenizer.h
index 6cf5552..b4c511f 100644
--- a/gemma/tokenizer.h
+++ b/gemma/tokenizer.h
@@ -63,11 +63,18 @@ class GemmaChatTemplate {
   void Init(const GemmaTokenizer& tokenizer);
   std::vector<int> Apply(size_t pos,
                          const std::vector<int>& ids) const;
+  std::vector<int> WrapPali(const std::vector<int>& text_part,
+                            size_t image_batch_size) const;
+  std::vector<int> WrapVLM(const std::vector<int>& text_part,
+                           size_t image_batch_size) const;
 
  private:
   std::vector<int> sot_user_;
   std::vector<int> sot_model_;
   std::vector<int> eot_;
+  std::vector<int> pali_sep_;
+  std::vector<int> vlm_soi_;
+  std::vector<int> vlm_eoi_;
 };
 
 std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
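
For review context, a minimal caller-side sketch of how the wrapping added
above is reached. It assumes an already-loaded GemmaTokenizer and a ModelInfo
whose wrapping field selects a vision variant; BuildVisionPrompt is a
hypothetical helper, not part of this patch, and the prompt parameter of the
WrapAndTokenize overload is assumed to be a std::string based on the hunk
context.

// Sketch only: BuildVisionPrompt is a hypothetical caller, not part of this
// patch.
#include <cstddef>
#include <string>
#include <vector>

#include "gemma/tokenizer.h"

std::vector<int> BuildVisionPrompt(const GemmaTokenizer& tokenizer,
                                   const ModelInfo& info, size_t pos,
                                   const std::string& prompt,
                                   size_t image_batch_size) {
  GemmaChatTemplate chat_template;
  chat_template.Init(tokenizer);  // Also pre-encodes the sep/soi/eoi pieces.
  // PALIGEMMA (asserts pos == 0): image_batch_size zeros, BOS, text, "\n".
  // GEMMA_VLM: text, "\n\n<start_of_image>", image_batch_size copies of the
  // -2 placeholder, "<end_of_image>\n\n", wrapped in a user turn by Apply().
  return WrapAndTokenize(tokenizer, chat_template, info, pos, prompt,
                         image_batch_size);
}

As the WrapVLM body suggests, the -2 entries act as placeholders marking
where the image tokens are spliced in later; PaliGemma instead reserves
image_batch_size leading zeros ahead of BOS.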