mirror of https://github.com/google/gemma.cpp.git
Add support for a secondary EOS token
PiperOrigin-RevId: 738898976
This commit is contained in:
parent
83219e3c68
commit
05b1cce9f7
|
|
@ -83,7 +83,7 @@ int main(int argc, char** argv) {
|
||||||
++generated;
|
++generated;
|
||||||
if (generated < prompt_size) {
|
if (generated < prompt_size) {
|
||||||
// print feedback
|
// print feedback
|
||||||
} else if (token != gcpp::EOS_ID) {
|
} else if (!model.GetModelConfig().IsEOS(token)) {
|
||||||
std::string token_text;
|
std::string token_text;
|
||||||
HWY_ASSERT(model.Tokenizer().Decode({token}, &token_text));
|
HWY_ASSERT(model.Tokenizer().Decode({token}, &token_text));
|
||||||
std::cout << token_text << std::flush;
|
std::cout << token_text << std::flush;
|
||||||
|
|
|
||||||
|
|
@ -80,7 +80,7 @@ class SimplifiedGemma {
|
||||||
++generated;
|
++generated;
|
||||||
if (generated < prompt_size) {
|
if (generated < prompt_size) {
|
||||||
// print feedback
|
// print feedback
|
||||||
} else if (token != gcpp::EOS_ID) {
|
} else if (!this->model_.GetModelConfig().IsEOS(token)) {
|
||||||
std::string token_text;
|
std::string token_text;
|
||||||
HWY_ASSERT(this->model_.Tokenizer().Decode({token}, &token_text));
|
HWY_ASSERT(this->model_.Tokenizer().Decode({token}, &token_text));
|
||||||
std::cout << token_text << std::flush;
|
std::cout << token_text << std::flush;
|
||||||
|
|
|
||||||
|
|
@ -195,6 +195,8 @@ static ModelConfig ConfigGemmaTiny() {
|
||||||
config.attention_window_sizes = FixedAttentionWindowSizes<3>(32);
|
config.attention_window_sizes = FixedAttentionWindowSizes<3>(32);
|
||||||
// This is required for optimize_test to pass.
|
// This is required for optimize_test to pass.
|
||||||
config.final_cap = 30.0f;
|
config.final_cap = 30.0f;
|
||||||
|
config.eos_id = 11;
|
||||||
|
config.secondary_eos_id = 11;
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -333,6 +335,8 @@ static ModelConfig ConfigBaseGemmaV3() {
|
||||||
ModelConfig config = ConfigNoSSM();
|
ModelConfig config = ConfigNoSSM();
|
||||||
config.att_cap = 0.0f;
|
config.att_cap = 0.0f;
|
||||||
config.final_cap = 0.0f;
|
config.final_cap = 0.0f;
|
||||||
|
config.eos_id = 1;
|
||||||
|
config.secondary_eos_id = 106;
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -294,6 +294,8 @@ struct ModelConfig : public IFields {
|
||||||
|
|
||||||
const char* Name() const override { return "ModelConfig"; }
|
const char* Name() const override { return "ModelConfig"; }
|
||||||
|
|
||||||
|
bool IsEOS(int id) const { return (id == eos_id || id == secondary_eos_id); }
|
||||||
|
|
||||||
void VisitFields(IFieldsVisitor& visitor) override {
|
void VisitFields(IFieldsVisitor& visitor) override {
|
||||||
visitor(model_family_version);
|
visitor(model_family_version);
|
||||||
visitor(model_name);
|
visitor(model_name);
|
||||||
|
|
@ -315,6 +317,8 @@ struct ModelConfig : public IFields {
|
||||||
visitor(norm_num_groups);
|
visitor(norm_num_groups);
|
||||||
visitor(vit_config);
|
visitor(vit_config);
|
||||||
visitor(pool_dim);
|
visitor(pool_dim);
|
||||||
|
visitor(eos_id);
|
||||||
|
visitor(secondary_eos_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Major version of the model family. It is used as a fallback to distinguish
|
// Major version of the model family. It is used as a fallback to distinguish
|
||||||
|
|
@ -341,6 +345,8 @@ struct ModelConfig : public IFields {
|
||||||
// Dimensions related to image processing.
|
// Dimensions related to image processing.
|
||||||
VitConfig vit_config;
|
VitConfig vit_config;
|
||||||
uint32_t pool_dim = 1; // used only for VitConfig copy
|
uint32_t pool_dim = 1; // used only for VitConfig copy
|
||||||
|
int eos_id = 1;
|
||||||
|
int secondary_eos_id = 1;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Returns the config for the given model.
|
// Returns the config for the given model.
|
||||||
|
|
|
||||||
|
|
@ -1299,8 +1299,9 @@ static size_t MaxQueryLength(const QueriesPromptTokens& queries_prompt) {
|
||||||
// Holds "is at end of stream" state for each query.
|
// Holds "is at end of stream" state for each query.
|
||||||
class TokenStreamer {
|
class TokenStreamer {
|
||||||
public:
|
public:
|
||||||
explicit TokenStreamer(const RuntimeConfig& runtime_config)
|
explicit TokenStreamer(const RuntimeConfig& runtime_config,
|
||||||
: runtime_config_(runtime_config) {}
|
const ModelConfig& model_config)
|
||||||
|
: runtime_config_(runtime_config), model_config_(model_config) {}
|
||||||
|
|
||||||
// Returns whether the query was already at, or has just reached, the end of
|
// Returns whether the query was already at, or has just reached, the end of
|
||||||
// the stream: either via token == eos_id, or StreamToken returning false.
|
// the stream: either via token == eos_id, or StreamToken returning false.
|
||||||
|
|
@ -1308,7 +1309,7 @@ class TokenStreamer {
|
||||||
if (HWY_UNLIKELY(is_eos_.Get(query_idx))) return true;
|
if (HWY_UNLIKELY(is_eos_.Get(query_idx))) return true;
|
||||||
|
|
||||||
if (!runtime_config_.StreamToken(query_idx, pos, token, prob) ||
|
if (!runtime_config_.StreamToken(query_idx, pos, token, prob) ||
|
||||||
token == runtime_config_.eos_id) {
|
model_config_.IsEOS(token)) {
|
||||||
is_eos_.Set(query_idx);
|
is_eos_.Set(query_idx);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
@ -1318,6 +1319,7 @@ class TokenStreamer {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const RuntimeConfig& runtime_config_;
|
const RuntimeConfig& runtime_config_;
|
||||||
|
const ModelConfig& model_config_;
|
||||||
hwy::BitSet4096<> is_eos_;
|
hwy::BitSet4096<> is_eos_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -1425,7 +1427,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
|
||||||
// Sanity check: prompts should not be empty, nor start with EOS.
|
// Sanity check: prompts should not be empty, nor start with EOS.
|
||||||
for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
|
for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
|
||||||
const PromptTokens& prompt = queries_prompt[query_idx];
|
const PromptTokens& prompt = queries_prompt[query_idx];
|
||||||
HWY_ASSERT(prompt.size() != 0 && prompt[0] != runtime_config.eos_id);
|
HWY_ASSERT(prompt.size() != 0 && !model.Config().IsEOS(prompt[0]));
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t num_queries = queries_prompt.size();
|
const size_t num_queries = queries_prompt.size();
|
||||||
|
|
@ -1469,7 +1471,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
|
||||||
std::vector<int> gen_tokens(num_queries);
|
std::vector<int> gen_tokens(num_queries);
|
||||||
|
|
||||||
// Stream the last prompt token from each query and fill gen_tokens.
|
// Stream the last prompt token from each query and fill gen_tokens.
|
||||||
TokenStreamer token_streamer(runtime_config);
|
TokenStreamer token_streamer(runtime_config, model.Config());
|
||||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||||
size_t last_token_pos_in_prompt =
|
size_t last_token_pos_in_prompt =
|
||||||
queries_mutable_pos[query_idx] - queries_pos_in[query_idx];
|
queries_mutable_pos[query_idx] - queries_pos_in[query_idx];
|
||||||
|
|
|
||||||
|
|
@ -118,7 +118,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
||||||
// callback function invoked for each generated token.
|
// callback function invoked for each generated token.
|
||||||
auto stream_token = [&](int token, float) {
|
auto stream_token = [&](int token, float) {
|
||||||
++abs_pos;
|
++abs_pos;
|
||||||
if (token == EOS_ID) {
|
if (model.GetModelConfig().IsEOS(token)) {
|
||||||
if (app.verbosity >= 2) {
|
if (app.verbosity >= 2) {
|
||||||
std::cout << "\n[ End ]\n";
|
std::cout << "\n[ End ]\n";
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@ namespace gcpp {
|
||||||
|
|
||||||
// The tokenizer's end of sentence and beginning of sentence token ids.
|
// The tokenizer's end of sentence and beginning of sentence token ids.
|
||||||
constexpr int EOS_ID = 1;
|
constexpr int EOS_ID = 1;
|
||||||
|
constexpr int SECONDARY_EOS_ID = 106; // for Gemma 3
|
||||||
constexpr int BOS_ID = 2;
|
constexpr int BOS_ID = 2;
|
||||||
|
|
||||||
class GemmaTokenizer {
|
class GemmaTokenizer {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue