Add support for a secondary EOS token

PiperOrigin-RevId: 738898976
Phil Culliton 2025-03-20 12:27:44 -07:00 committed by Copybara-Service
parent 83219e3c68
commit 05b1cce9f7
7 changed files with 21 additions and 8 deletions
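
In short: generation previously stopped only on the single hard-coded EOS_ID, but Gemma 3 instruction-tuned models end turns with a different control token (id 106, <end_of_turn> in the Gemma vocabulary). This commit moves the stop check into ModelConfig, which now carries a primary and a secondary EOS id and exposes IsEOS(). A minimal standalone sketch of the pattern — the Config struct and token values below are illustrative, not the library's API:

#include <iostream>
#include <vector>

// Illustrative stand-in for ModelConfig: two stop ids, one membership test.
struct Config {
  int eos_id = 1;              // <eos>
  int secondary_eos_id = 106;  // <end_of_turn> (Gemma 3)
  bool IsEOS(int id) const { return id == eos_id || id == secondary_eos_id; }
};

int main() {
  const Config config;
  // Toy decode loop: generation stops at the first token that is either
  // kind of EOS, instead of comparing against one constant.
  for (int token : std::vector<int>{512, 2031, 106, 77}) {
    if (config.IsEOS(token)) break;  // stops at 106
    std::cout << token << ' ';
  }
  std::cout << "[ End ]\n";
}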


@@ -83,7 +83,7 @@ int main(int argc, char** argv) {
     ++generated;
     if (generated < prompt_size) {
       // print feedback
-    } else if (token != gcpp::EOS_ID) {
+    } else if (!model.GetModelConfig().IsEOS(token)) {
       std::string token_text;
       HWY_ASSERT(model.Tokenizer().Decode({token}, &token_text));
       std::cout << token_text << std::flush;


@@ -80,7 +80,7 @@ class SimplifiedGemma {
     ++generated;
     if (generated < prompt_size) {
       // print feedback
-    } else if (token != gcpp::EOS_ID) {
+    } else if (!this->model_.GetModelConfig().IsEOS(token)) {
       std::string token_text;
       HWY_ASSERT(this->model_.Tokenizer().Decode({token}, &token_text));
       std::cout << token_text << std::flush;


@@ -195,6 +195,8 @@ static ModelConfig ConfigGemmaTiny() {
   config.attention_window_sizes = FixedAttentionWindowSizes<3>(32);
   // This is required for optimize_test to pass.
   config.final_cap = 30.0f;
+  config.eos_id = 11;
+  config.secondary_eos_id = 11;
   return config;
 }

@@ -333,6 +335,8 @@ static ModelConfig ConfigBaseGemmaV3() {
   ModelConfig config = ConfigNoSSM();
   config.att_cap = 0.0f;
   config.final_cap = 0.0f;
+  config.eos_id = 1;
+  config.secondary_eos_id = 106;
   return config;
 }
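
Note that ConfigGemmaTiny assigns the same id (11) to both fields, so the membership test degenerates gracefully to the old single-token comparison; likewise, the field defaults (both 1, see the struct below) leave configs that never set a secondary id behaving exactly as before. A quick check of both cases, as a standalone sketch rather than library code:

#include <cassert>

struct Eos {
  int primary, secondary;
  bool IsEOS(int id) const { return id == primary || id == secondary; }
};

int main() {
  const Eos tiny{11, 11};  // both ids equal: one effective stop token
  const Eos v3{1, 106};    // either <eos> or <end_of_turn> ends the stream
  assert(tiny.IsEOS(11) && !tiny.IsEOS(1));
  assert(v3.IsEOS(1) && v3.IsEOS(106) && !v3.IsEOS(11));
}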


@@ -294,6 +294,8 @@ struct ModelConfig : public IFields {
   const char* Name() const override { return "ModelConfig"; }

+  bool IsEOS(int id) const { return (id == eos_id || id == secondary_eos_id); }
+
   void VisitFields(IFieldsVisitor& visitor) override {
     visitor(model_family_version);
     visitor(model_name);

@@ -315,6 +317,8 @@ struct ModelConfig : public IFields {
     visitor(norm_num_groups);
     visitor(vit_config);
     visitor(pool_dim);
+    visitor(eos_id);
+    visitor(secondary_eos_id);
   }

   // Major version of the model family. It is used as a fallback to distinguish

@@ -341,6 +345,8 @@ struct ModelConfig : public IFields {
   // Dimensions related to image processing.
   VitConfig vit_config;
   uint32_t pool_dim = 1;  // used only for VitConfig copy
+  int eos_id = 1;
+  int secondary_eos_id = 1;
 };

 // Returns the config for the given model.
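
Because ModelConfig implements IFields, listing the two new fields in VisitFields is what makes them participate in config (de)serialization. A rough sketch of the visitor idea, with a toy visitor type standing in for IFieldsVisitor (the real gemma.cpp interface differs):

#include <functional>
#include <iostream>

// Toy visitor: one callback applied to every registered field.
struct Visitor {
  std::function<void(int&)> visit_int;
  void operator()(int& field) { visit_int(field); }
};

struct ToyConfig {
  int eos_id = 1;
  int secondary_eos_id = 106;
  // Every field listed here is reached by any visitor, so a serializer,
  // printer, or comparator sees the new ids without further changes.
  void VisitFields(Visitor& visitor) {
    visitor(eos_id);
    visitor(secondary_eos_id);
  }
};

int main() {
  ToyConfig config;
  Visitor printer{[](int& field) { std::cout << field << '\n'; }};
  config.VisitFields(printer);  // prints 1, then 106
}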


@@ -1299,8 +1299,9 @@ static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) {

 // Holds "is at end of stream" state for each query.
 class TokenStreamer {
  public:
-  explicit TokenStreamer(const RuntimeConfig& runtime_config)
-      : runtime_config_(runtime_config) {}
+  explicit TokenStreamer(const RuntimeConfig& runtime_config,
+                         const ModelConfig& model_config)
+      : runtime_config_(runtime_config), model_config_(model_config) {}

   // Returns whether the query was already at, or has just reached, the end of
   // the stream: either via token == eos_id, or StreamToken returning false.

@@ -1308,7 +1309,7 @@ class TokenStreamer {
     if (HWY_UNLIKELY(is_eos_.Get(query_idx))) return true;
     if (!runtime_config_.StreamToken(query_idx, pos, token, prob) ||
-        token == runtime_config_.eos_id) {
+        model_config_.IsEOS(token)) {
       is_eos_.Set(query_idx);
       return true;
     }

@@ -1318,6 +1319,7 @@ class TokenStreamer {
  private:
   const RuntimeConfig& runtime_config_;
+  const ModelConfig& model_config_;
   hwy::BitSet4096<> is_eos_;
 };

@@ -1425,7 +1427,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   // Sanity check: prompts should not be empty, nor start with EOS.
   for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
     const PromptTokens& prompt = queries_prompt[query_idx];
-    HWY_ASSERT(prompt.size() != 0 && prompt[0] != runtime_config.eos_id);
+    HWY_ASSERT(prompt.size() != 0 && !model.Config().IsEOS(prompt[0]));
   }

   const size_t num_queries = queries_prompt.size();

@@ -1469,7 +1471,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   std::vector<int> gen_tokens(num_queries);

   // Stream the last prompt token from each query and fill gen_tokens.
-  TokenStreamer token_streamer(runtime_config);
+  TokenStreamer token_streamer(runtime_config, model.Config());
   for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
     size_t last_token_pos_in_prompt =
         queries_mutable_pos[query_idx] - queries_pos_in[query_idx];
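
TokenStreamer tracks end-of-stream per query so that batched generation keeps streaming tokens only for still-live queries; with this change it consults the model config rather than a runtime field. A condensed sketch of that bookkeeping, where std::vector<bool> stands in for hwy::BitSet4096<> and the stream callback is simplified away:

#include <cstdio>
#include <vector>

struct Config {
  int eos_id = 1;
  int secondary_eos_id = 106;
  bool IsEOS(int id) const { return id == eos_id || id == secondary_eos_id; }
};

// Per-query "reached EOS" state, in the spirit of TokenStreamer.
class Streamer {
 public:
  Streamer(const Config& config, size_t num_queries)
      : config_(config), is_eos_(num_queries, false) {}

  // Returns whether the query was already at, or has just reached, EOS.
  bool operator()(size_t query_idx, int token) {
    if (is_eos_[query_idx]) return true;
    std::printf("q%zu: token %d\n", query_idx, token);
    if (config_.IsEOS(token)) {
      is_eos_[query_idx] = true;
      return true;
    }
    return false;
  }

 private:
  const Config& config_;
  std::vector<bool> is_eos_;
};

int main() {
  const Config config;
  Streamer streamer(config, /*num_queries=*/2);
  streamer(0, 42);   // query 0 still live
  streamer(1, 106);  // secondary EOS ends query 1
  streamer(1, 7);    // ignored: query 1 already finished
}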


@@ -118,7 +118,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
   // callback function invoked for each generated token.
   auto stream_token = [&](int token, float) {
     ++abs_pos;
-    if (token == EOS_ID) {
+    if (model.GetModelConfig().IsEOS(token)) {
       if (app.verbosity >= 2) {
         std::cout << "\n[ End ]\n";
       }


@@ -29,6 +29,7 @@ namespace gcpp {

 // The tokenizer's end of sentence and beginning of sentence token ids.
 constexpr int EOS_ID = 1;
+constexpr int SECONDARY_EOS_ID = 106;  // for Gemma 3
 constexpr int BOS_ID = 2;

 class GemmaTokenizer {