diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h
index f6e32c0..9d4773a 100644
--- a/evals/benchmark_helper.h
+++ b/evals/benchmark_helper.h
@@ -69,7 +69,8 @@ class GemmaEnv {
   }
 
   std::vector<int> WrapAndTokenize(std::string& input) const {
-    return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->Info(), 0, input);
+    return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->ChatTemplate(),
+                                 model_->Info(), 0, input);
   }
 
   std::string StringFromTokens(const std::vector<int>& tokens) const {
diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc
index 7674c5e..c73bec6 100644
--- a/evals/gemma_test.cc
+++ b/evals/gemma_test.cc
@@ -178,22 +178,25 @@ TEST_F(GemmaTest, Multiturn) {
   TimingInfo timing_info{.verbosity = 0};
   // First "say" something slightly unusual.
   std::string mutable_prompt = "I have a car and its color is turquoise.";
-  std::vector<int> tokens = WrapAndTokenize(model->Tokenizer(), model->Info(),
-                                            abs_pos, mutable_prompt);
+  std::vector<int> tokens =
+      WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(), model->Info(),
+                      abs_pos, mutable_prompt);
+
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                   timing_info);
   // Note: we do not rewind any <end_of_turn> tokens here. If the model
   // produced one and WrapAndTokenize() inserts another one, it will just be
   // duplicated.
   mutable_prompt = "Please repeat all prior statements.";
-  tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
-                           mutable_prompt);
+  tokens = WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
+                           model->Info(), abs_pos, mutable_prompt);
+
   // Reset the `response` string here, then check that the model actually has
   // access to the previous turn by asking to reproduce.
   response.clear();
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                   timing_info);
-  fprintf(stderr, "decoded: %s\n", response.c_str());
+  fprintf(stderr, "decoded: '%s'\n", response.c_str());
   bool remembered_turquoise =
       response.find("turquoise") != std::string::npos;  // NOLINT
   bool remembered_car = response.find("car") != std::string::npos;  // NOLINT
diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index fb2fea3..8f65b15 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -74,8 +74,9 @@ int main(int argc, char** argv) {
 
   // Tokenize instructions.
std::string prompt = "Write a greeting to the world."; - const std::vector tokens = gcpp::WrapAndTokenize( - model.Tokenizer(), loader.Info(), generated, prompt); + const std::vector tokens = + gcpp::WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), + loader.Info(), generated, prompt); const size_t prompt_size = tokens.size(); // This callback function gets invoked every time a token is generated diff --git a/examples/simplified_gemma/gemma.hpp b/examples/simplified_gemma/gemma.hpp index a2a7760..5047866 100644 --- a/examples/simplified_gemma/gemma.hpp +++ b/examples/simplified_gemma/gemma.hpp @@ -72,7 +72,8 @@ class SimplifiedGemma { size_t generated = 0; const std::vector tokens = gcpp::WrapAndTokenize( - model_.Tokenizer(), loader_.Info(), generated, prompt); + model_.Tokenizer(), model_.ChatTemplate(), loader_.Info(), + generated, prompt); const size_t prompt_size = tokens.size(); // This callback function gets invoked every time a token is generated @@ -115,4 +116,4 @@ class SimplifiedGemma { gcpp::KVCache kv_cache_; std::mt19937 gen_; std::string validation_error_; -}; \ No newline at end of file +}; diff --git a/gemma/common.cc b/gemma/common.cc index da782c5..9d5db95 100644 --- a/gemma/common.cc +++ b/gemma/common.cc @@ -25,6 +25,8 @@ #include #include "util/basics.h" // BF16 +// TODO: change include when PromptWrapping is moved. +#include "compression/shared.h" // PromptWrapping #include "hwy/base.h" namespace gcpp { @@ -79,7 +81,7 @@ constexpr PromptWrapping kPromptWrapping[] = { PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 3B 224/448 PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 10B 224/448 PromptWrapping::GEMMA_VLM, // Gemma3 4B - PromptWrapping::GEMMA_PT, // Gemma3 1B + PromptWrapping::GEMMA_IT, // Gemma3 1B PromptWrapping::GEMMA_VLM, // Gemma3 12B PromptWrapping::GEMMA_VLM, // Gemma3 27B }; diff --git a/gemma/common.h b/gemma/common.h index 8aa2112..d88a742 100644 --- a/gemma/common.h +++ b/gemma/common.h @@ -44,6 +44,7 @@ const char* ModelString(Model model, PromptWrapping wrapping); const char* StringFromType(Type type); // Wraps the given prompt using the expected control tokens for IT models. +// `GemmaChatTemplate` is preferred if a tokenized return value is fine. void Wrap(const ModelInfo& info, size_t pos, std::string& prompt); // Returns the scale value to use for the embedding (basically sqrt model_dim). 
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index bfc6534..658ff66 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -44,6 +44,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
   model_.Load(weights, info.model, info.weight, info.wrapping,
               env_.parallel.Pools().Pool(0),
               /*tokenizer_proto=*/nullptr);
+  chat_template_.Init(tokenizer_, model_.Config().model);
 }
 
 Gemma::Gemma(const Path& weights, MatMulEnv& env) : env_(env) {
@@ -51,10 +52,13 @@
   model_.Load(weights, Model::UNKNOWN, Type::kUnknown,
               PromptWrapping::GEMMA_IT, env_.parallel.Pools().Pool(0),
               &tokenizer_proto);
   tokenizer_.Deserialize(tokenizer_proto);
+  chat_template_.Init(tokenizer_, model_.Config().model);
 }
 
 Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env)
-    : env_(env), tokenizer_(std::move(tokenizer)) {
+    : env_(env),
+      tokenizer_(std::move(tokenizer)),
+      chat_template_(tokenizer_, info.model) {
   HWY_ASSERT(info.weight == Type::kF32);
   model_.Allocate(info.model, info.weight, env_.parallel.Pools().Pool(0));
 }
diff --git a/gemma/gemma.h b/gemma/gemma.h
index ccda69c..de0cba1 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -213,6 +213,7 @@ class Gemma {
                       .weight = model_.Config().weight});
   }
   const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
+  const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
   const ModelWeightsStorage& Weights() const { return model_; }
   ModelWeightsStorage& MutableWeights() { return model_; }
   void Save(const Path& weights, hwy::ThreadPool& pool) {
@@ -256,6 +257,7 @@ class Gemma {
 
   MatMulEnv& env_;
   GemmaTokenizer tokenizer_;
+  GemmaChatTemplate chat_template_;
   // Type-erased so that this can be defined in the header.
   ModelWeightsStorage model_;
 };
diff --git a/gemma/run.cc b/gemma/run.cc
index 254d13f..a437ae0 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -162,16 +162,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
       continue;
     }
 
-    // Wrap, tokenize and maybe log prompt tokens.
-    std::vector<int> prompt = WrapAndTokenize(
-        model.Tokenizer(), model.Info(), abs_pos, prompt_string);
-    prompt_size = prompt.size();
-    if constexpr (kVerboseLogTokens) {
-      for (int i = 0; i < prompt_size; ++i) {
-        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
-      }
-    }
-
     // Set up runtime config.
     TimingInfo timing_info = {.verbosity = app.verbosity};
     RuntimeConfig runtime_config = {.gen = &gen,
@@ -181,23 +171,29 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
                                     .use_spinning = app.spin};
     args.CopyTo(runtime_config);
     size_t prefix_end = 0;
+
+    std::vector<int> prompt;
     if (have_image) {
+      prompt =
+          WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
+                          abs_pos, prompt_string, image_tokens.BatchSize());
       runtime_config.image_tokens = &image_tokens;
-      if (model.Info().wrapping == PromptWrapping::PALIGEMMA) {
-        prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
-      } else if (model.Info().wrapping == PromptWrapping::GEMMA_VLM) {
-        size_t seq_len = model.GetModelConfig().vit_config.seq_len;
-        size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
-        prompt =
-            WrapVLM(model.Tokenizer(), model.Info(), abs_pos, prompt,
-                    image_tokens.BatchSize(), seq_len / (pool_dim * pool_dim));
-      }
       prompt_size = prompt.size();
       // The end of the prefix for prefix-LM style attention in Paligemma.
       // See Figure 2 of https://arxiv.org/abs/2407.07726.
       prefix_end = prompt_size;
       // We need to look at all the tokens for the prefix.
       runtime_config.prefill_tbatch_size = prompt_size;
+    } else {
+      prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
+                               model.Info(), abs_pos, prompt_string);
+      prompt_size = prompt.size();
+    }
+
+    if constexpr (kVerboseLogTokens) {
+      for (int i = 0; i < prompt_size; ++i) {
+        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
+      }
     }
 
     // Generate until EOS or max_generated_tokens.
diff --git a/gemma/tokenizer.cc b/gemma/tokenizer.cc
index e48abae..83f3429 100644
--- a/gemma/tokenizer.cc
+++ b/gemma/tokenizer.cc
@@ -100,71 +100,123 @@ void GemmaTokenizer::Deserialize(const std::string& tokenizer_proto) {
 
 bool GemmaTokenizer::Encode(const std::string& input,
                             std::vector<std::string>* pieces) const {
-  return impl_->Encode(input, pieces);
+  return impl_ && impl_->Encode(input, pieces);
 }
 
 bool GemmaTokenizer::Encode(const std::string& input,
                             std::vector<int>* ids) const {
-  return impl_->Encode(input, ids);
+  return impl_ && impl_->Encode(input, ids);
 }
 
 // Given a sequence of ids, decodes it into a detokenized output.
 bool GemmaTokenizer::Decode(const std::vector<int>& ids,
                             std::string* detokenized) const {
-  return impl_->Decode(ids, detokenized);
+  return impl_ && impl_->Decode(ids, detokenized);
 }
 
-std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
-                                 const ModelInfo& info, size_t pos,
-                                 std::string& prompt) {
-  Wrap(info, pos, prompt);
+bool GemmaChatTemplate::Init(const GemmaTokenizer& tokenizer, Model model) {
+  sot_user_.reserve(3);
+  if (!tokenizer.Encode("<start_of_turn>user\n", &sot_user_)) return false;
+  sot_model_.reserve(3);
+  HWY_ASSERT(tokenizer.Encode("<start_of_turn>model\n", &sot_model_));
+  eot_.reserve(2);
+  HWY_ASSERT(tokenizer.Encode("<end_of_turn>\n", &eot_));
+  HWY_ASSERT(tokenizer.Encode("\n", &pali_sep_));
+  vlm_soi_.reserve(2);
+  HWY_ASSERT(tokenizer.Encode("\n\n<start_of_image>", &vlm_soi_));
+  vlm_eoi_.reserve(2);
+  HWY_ASSERT(tokenizer.Encode("<end_of_image>\n\n", &vlm_eoi_));
+  return true;
+}
+
+std::vector<int> GemmaChatTemplate::Apply(size_t pos,
+                                          const std::vector<int>& ids) const {
+  HWY_ASSERT_M(!sot_user_.empty() && !sot_model_.empty() && !eot_.empty(),
+               "GemmaChatTemplate has not been initialized.");
+  std::vector<int> out;
+  out.reserve(eot_.size() + sot_user_.size() + ids.size() + eot_.size() +
+              sot_model_.size());
+
+  // Start with BOS, or prepend <end_of_turn> if this is a continuation.
+  if (pos == 0) {
+    out.push_back(BOS_ID);
+  } else {
+    out.insert(out.cend(), eot_.cbegin(), eot_.cend());
+  }
+  // Start of user turn, user prompt, end of turn; then start of model turn.
+  out.insert(out.cend(), sot_user_.cbegin(), sot_user_.cend());
+  out.insert(out.cend(), ids.cbegin(), ids.cend());
+  out.insert(out.cend(), eot_.cbegin(), eot_.cend());
+  out.insert(out.cend(), sot_model_.cbegin(), sot_model_.cend());
+  return out;
+}
+
+std::vector<int> GemmaChatTemplate::WrapPali(const std::vector<int>& text_part,
+                                             size_t image_batch_size) const {
+  HWY_ASSERT_M(!pali_sep_.empty(),
+               "GemmaChatTemplate has not been initialized.");
+  std::vector<int> out;
+  out.reserve(image_batch_size + 1 + text_part.size() + pali_sep_.size());
+  out.resize(image_batch_size, 0);
+  out.push_back(BOS_ID);
+  out.insert(out.cend(), text_part.cbegin(), text_part.cend());
+  out.insert(out.cend(), pali_sep_.cbegin(), pali_sep_.cend());
+  return out;
+}
+
+std::vector<int> GemmaChatTemplate::WrapVLM(const std::vector<int>& text_part,
+                                            size_t image_batch_size) const {
+  HWY_ASSERT_M(!vlm_soi_.empty() && !vlm_eoi_.empty(),
+               "GemmaChatTemplate has not been initialized.");
+  std::vector<int> out;
+  out.reserve(text_part.size() + vlm_soi_.size() + image_batch_size +
+              vlm_eoi_.size());
+  out.insert(out.cend(), text_part.cbegin(), text_part.cend());
+  out.insert(out.cend(), vlm_soi_.cbegin(), vlm_soi_.cend());
+  out.insert(out.cend(), image_batch_size, -2);
+  out.insert(out.cend(), vlm_eoi_.cbegin(), vlm_eoi_.cend());
+  return out;
+}
+
+// Text
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt) {
   std::vector<int> tokens;
   HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
-  // Both pre-trained and instruction-tuned require BOS as first token.
-  if (pos == 0) {
-    tokens.insert(tokens.begin(), BOS_ID);
-  }
-  // PaliGemma separator. The SEP token "\n" is always tokenized separately.
-  if (info.wrapping == PromptWrapping::PALIGEMMA
-      // || info.wrapping == PromptWrapping::GEMMA_VLM
-  ) {
-    std::vector<int> sep_tokens;
-    HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
-    tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
+  switch (info.wrapping) {
+    case PromptWrapping::GEMMA_IT:
+    case PromptWrapping::GEMMA_VLM:
+      return chat_template.Apply(pos, tokens);
+    default:
+      if (pos == 0) {
+        tokens.insert(tokens.cbegin(), BOS_ID);
+      }
+      return tokens;
   }
-
-  return tokens;
 }
 
-std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
-                         size_t pos, std::vector<int>& tokens,
-                         size_t image_batch_size, size_t max_image_batch_size) {
-  HWY_ASSERT(info.wrapping == PromptWrapping::GEMMA_VLM);
-  size_t num_images = hwy::DivCeil(image_batch_size, max_image_batch_size);
-
-  std::vector<int> sep_tokens;
-  HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
-
-  std::string begin_image_prompt = "\n\n<start_of_image>";
-  std::vector<int> begin_image_tokens =
-      WrapAndTokenize(tokenizer, info, pos, begin_image_prompt);
-
-  std::string end_image_prompt = "<end_of_image>\n\n";
-  std::vector<int> end_image_tokens =
-      WrapAndTokenize(tokenizer, info, pos, end_image_prompt);
-
-  for (size_t i = 0; i < num_images; ++i) {
-    tokens.insert(tokens.begin(), begin_image_tokens.begin(),
-                  begin_image_tokens.end());
-    tokens.insert(tokens.begin() + begin_image_tokens.size(), image_batch_size,
-                  -2);
-    tokens.insert(tokens.begin() + begin_image_tokens.size() + image_batch_size,
-                  end_image_tokens.begin(), end_image_tokens.end());
+// Vision
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt,
+                                 size_t image_batch_size) {
+  std::vector<int> text_part;
+  HWY_ASSERT(tokenizer.Encode(prompt, &text_part));
+  switch (info.wrapping) {
+    case PromptWrapping::PALIGEMMA:
+      HWY_ASSERT(pos == 0);
+      return chat_template.WrapPali(text_part, image_batch_size);
+    case PromptWrapping::GEMMA_VLM:
+      return chat_template.Apply(
+          pos, chat_template.WrapVLM(text_part, image_batch_size));
+    default:
+      HWY_ASSERT_M(false, "Current variant does not support vision prompt.");
   }
-
-  return tokens;
 }
 
 }  // namespace gcpp
diff --git a/gemma/tokenizer.h b/gemma/tokenizer.h
index 0bbd8f4..ff8f91e 100644
--- a/gemma/tokenizer.h
+++ b/gemma/tokenizer.h
@@ -54,13 +54,43 @@ class GemmaTokenizer {
   std::unique_ptr<Impl> impl_;
 };
 
-std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
-                                 const ModelInfo& info, size_t pos,
-                                 std::string& prompt);
+class GemmaChatTemplate {
+ public:
+  GemmaChatTemplate() = default;
+  explicit GemmaChatTemplate(const GemmaTokenizer& tokenizer, Model model) {
+    (void)Init(tokenizer, model);
+  }
 
-std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
-                         size_t pos, std::vector<int>& tokens,
-                         size_t image_batch_size, size_t max_image_batch_size);
+  // Returns false if the tokenizer is not available (as in optimize_test.cc).
+  bool Init(const GemmaTokenizer& tokenizer, Model model);
+
+  // Given prompt tokens, this returns the wrapped prompt including BOS and
+  // any "start_of_turn" structure required by the model.
+  std::vector<int> Apply(size_t pos, const std::vector<int>& ids) const;
+  std::vector<int> WrapPali(const std::vector<int>& text_part,
+                            size_t image_batch_size) const;
+  std::vector<int> WrapVLM(const std::vector<int>& text_part,
+                           size_t image_batch_size) const;
+
+ private:
+  std::vector<int> sot_user_;
+  std::vector<int> sot_model_;
+  std::vector<int> eot_;
+  std::vector<int> pali_sep_;
+  std::vector<int> vlm_soi_;
+  std::vector<int> vlm_eoi_;
+};
+
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt);
+
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt,
+                                 size_t image_batch_size);
 
 }  // namespace gcpp
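
Usage sketch (not part of the patch): after this change, callers wrap prompts in token space by passing the model's ChatTemplate() into the new WrapAndTokenize() overloads, as the updated call sites above do. The snippet below is a minimal illustration assuming a loaded gcpp::Gemma named `model`, a size_t `abs_pos` tracking the position across turns (0 for the first turn), and, for the vision case, an `image_tokens` batch as in gemma/run.cc.

    // Text prompt. At pos == 0, Apply() produces
    // <bos><start_of_turn>user\n PROMPT <end_of_turn>\n<start_of_turn>model\n;
    // on later turns it prepends <end_of_turn>\n instead of <bos>.
    std::string prompt = "Write a greeting to the world.";
    const std::vector<int> tokens = gcpp::WrapAndTokenize(
        model.Tokenizer(), model.ChatTemplate(), model.Info(), abs_pos, prompt);

    // Vision prompt (PaliGemma / Gemma3 VLM): also pass the image batch size so
    // the template can reserve placeholder ids for the image tokens.
    const std::vector<int> vision_tokens = gcpp::WrapAndTokenize(
        model.Tokenizer(), model.ChatTemplate(), model.Info(), abs_pos, prompt,
        image_tokens.BatchSize());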