From 05b1cce9f7b28905d2e6a46a5f20a0164bdb836a Mon Sep 17 00:00:00 2001
From: Phil Culliton
Date: Thu, 20 Mar 2025 12:27:44 -0700
Subject: [PATCH] Add support for a secondary EOS token

PiperOrigin-RevId: 738898976
---
 examples/hello_world/run.cc         |  2 +-
 examples/simplified_gemma/gemma.hpp |  2 +-
 gemma/configs.cc                    |  4 ++++
 gemma/configs.h                     |  6 ++++++
 gemma/gemma-inl.h                   | 12 +++++++-----
 gemma/run.cc                        |  2 +-
 gemma/tokenizer.h                   |  1 +
 7 files changed, 21 insertions(+), 8 deletions(-)

diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc
index 9ef58ff..fb2fea3 100644
--- a/examples/hello_world/run.cc
+++ b/examples/hello_world/run.cc
@@ -83,7 +83,7 @@ int main(int argc, char** argv) {
     ++generated;
     if (generated < prompt_size) {
       // print feedback
-    } else if (token != gcpp::EOS_ID) {
+    } else if (!model.GetModelConfig().IsEOS(token)) {
       std::string token_text;
       HWY_ASSERT(model.Tokenizer().Decode({token}, &token_text));
       std::cout << token_text << std::flush;
diff --git a/examples/simplified_gemma/gemma.hpp b/examples/simplified_gemma/gemma.hpp
index bbbb86a..a2a7760 100644
--- a/examples/simplified_gemma/gemma.hpp
+++ b/examples/simplified_gemma/gemma.hpp
@@ -80,7 +80,7 @@ class SimplifiedGemma {
       ++generated;
       if (generated < prompt_size) {
         // print feedback
-      } else if (token != gcpp::EOS_ID) {
+      } else if (!this->model_.GetModelConfig().IsEOS(token)) {
         std::string token_text;
         HWY_ASSERT(this->model_.Tokenizer().Decode({token}, &token_text));
         std::cout << token_text << std::flush;
diff --git a/gemma/configs.cc b/gemma/configs.cc
index ebec631..d980b3b 100644
--- a/gemma/configs.cc
+++ b/gemma/configs.cc
@@ -195,6 +195,8 @@ static ModelConfig ConfigGemmaTiny() {
   config.attention_window_sizes = FixedAttentionWindowSizes<3>(32);
   // This is required for optimize_test to pass.
   config.final_cap = 30.0f;
+  config.eos_id = 11;
+  config.secondary_eos_id = 11;
   return config;
 }
 
@@ -333,6 +335,8 @@ static ModelConfig ConfigBaseGemmaV3() {
   ModelConfig config = ConfigNoSSM();
   config.att_cap = 0.0f;
   config.final_cap = 0.0f;
+  config.eos_id = 1;
+  config.secondary_eos_id = 106;
   return config;
 }
 
diff --git a/gemma/configs.h b/gemma/configs.h
index a5dba12..837e067 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -294,6 +294,8 @@ struct ModelConfig : public IFields {
 
   const char* Name() const override { return "ModelConfig"; }
 
+  bool IsEOS(int id) const { return (id == eos_id || id == secondary_eos_id); }
+
   void VisitFields(IFieldsVisitor& visitor) override {
     visitor(model_family_version);
     visitor(model_name);
@@ -315,6 +317,8 @@ struct ModelConfig : public IFields {
     visitor(norm_num_groups);
     visitor(vit_config);
     visitor(pool_dim);
+    visitor(eos_id);
+    visitor(secondary_eos_id);
   }
 
   // Major version of the model family. It is used as a fallback to distinguish
@@ -341,6 +345,8 @@ struct ModelConfig : public IFields {
   // Dimensions related to image processing.
   VitConfig vit_config;
   uint32_t pool_dim = 1;  // used only for VitConfig copy
+  int eos_id = 1;
+  int secondary_eos_id = 1;
 };
 
 // Returns the config for the given model.
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index d87be54..ab25d53 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -1299,8 +1299,9 @@ static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) {
 // Holds "is at end of stream" state for each query.
 class TokenStreamer {
  public:
-  explicit TokenStreamer(const RuntimeConfig& runtime_config)
-      : runtime_config_(runtime_config) {}
+  explicit TokenStreamer(const RuntimeConfig& runtime_config,
+                         const ModelConfig& model_config)
+      : runtime_config_(runtime_config), model_config_(model_config) {}
 
   // Returns whether the query was already at, or has just reached, the end of
   // the stream: either via token == eos_id, or StreamToken returning false.
@@ -1308,7 +1309,7 @@ class TokenStreamer {
     if (HWY_UNLIKELY(is_eos_.Get(query_idx))) return true;
 
     if (!runtime_config_.StreamToken(query_idx, pos, token, prob) ||
-        token == runtime_config_.eos_id) {
+        model_config_.IsEOS(token)) {
       is_eos_.Set(query_idx);
       return true;
     }
@@ -1318,6 +1319,7 @@ class TokenStreamer {
 
  private:
   const RuntimeConfig& runtime_config_;
+  const ModelConfig& model_config_;
   hwy::BitSet4096<> is_eos_;
 };
 
@@ -1425,7 +1427,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   // Sanity check: prompts should not be empty, nor start with EOS.
   for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
     const PromptTokens& prompt = queries_prompt[query_idx];
-    HWY_ASSERT(prompt.size() != 0 && prompt[0] != runtime_config.eos_id);
+    HWY_ASSERT(prompt.size() != 0 && !model.Config().IsEOS(prompt[0]));
   }
 
   const size_t num_queries = queries_prompt.size();
@@ -1469,7 +1471,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   std::vector<int> gen_tokens(num_queries);
 
   // Stream the last prompt token from each query and fill gen_tokens.
-  TokenStreamer token_streamer(runtime_config);
+  TokenStreamer token_streamer(runtime_config, model.Config());
   for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
     size_t last_token_pos_in_prompt =
         queries_mutable_pos[query_idx] - queries_pos_in[query_idx];
diff --git a/gemma/run.cc b/gemma/run.cc
index 669f614..8c23c15 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -118,7 +118,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
   // callback function invoked for each generated token.
   auto stream_token = [&](int token, float) {
     ++abs_pos;
-    if (token == EOS_ID) {
+    if (model.GetModelConfig().IsEOS(token)) {
      if (app.verbosity >= 2) {
        std::cout << "\n[ End ]\n";
      }
diff --git a/gemma/tokenizer.h b/gemma/tokenizer.h
index a5c4c4f..0bbd8f4 100644
--- a/gemma/tokenizer.h
+++ b/gemma/tokenizer.h
@@ -29,6 +29,7 @@ namespace gcpp {
 
 // The tokenizer's end of sentence and beginning of sentence token ids.
 constexpr int EOS_ID = 1;
+constexpr int SECONDARY_EOS_ID = 106;  // for Gemma 3
 constexpr int BOS_ID = 2;
 
 class GemmaTokenizer {
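
Note (not part of the patch): a minimal standalone sketch of the end-of-stream check this
change introduces. The MiniConfig struct below is a hypothetical stand-in that only mirrors
the new eos_id/secondary_eos_id fields and the IsEOS() helper added to gcpp::ModelConfig
above; the ids 1 and 106 are the values ConfigBaseGemmaV3() assigns in this diff.

  #include <cstdio>

  // Hypothetical stand-in for gcpp::ModelConfig, reduced to the fields this patch touches.
  struct MiniConfig {
    int eos_id = 1;
    int secondary_eos_id = 1;  // defaults to the same value as eos_id, as in the patch
    bool IsEOS(int id) const { return id == eos_id || id == secondary_eos_id; }
  };

  int main() {
    MiniConfig v3;
    v3.eos_id = 1;              // EOS_ID
    v3.secondary_eos_id = 106;  // SECONDARY_EOS_ID, "for Gemma 3"
    const int tokens[] = {1, 106, 42};
    for (int token : tokens) {
      std::printf("token %d -> IsEOS = %s\n", token, v3.IsEOS(token) ? "true" : "false");
    }
    return 0;
  }

With these values, both 1 and 106 end the stream for a Gemma 3 config, while other models
keep single-EOS behavior because secondary_eos_id defaults to the same id as eos_id.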