Refactor `WrapAndTokenize` to work properly with Gemma3

RangerUFO 2025-03-26 18:19:05 +08:00
parent 76a81ac2d6
commit ca4ee2b63f
11 changed files with 139 additions and 96 deletions


@@ -69,7 +69,7 @@ class GemmaEnv {
   }
   std::vector<int> WrapAndTokenize(std::string& input) const {
-    return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->Info(), 0, input);
+    return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->ChatTemplate(), model_->Info(), 0, input);
   }
   std::string StringFromTokens(const std::vector<int>& tokens) const {


@@ -178,16 +178,18 @@ TEST_F(GemmaTest, Multiturn) {
   TimingInfo timing_info{.verbosity = 0};
   // First "say" something slightly unusual.
   std::string mutable_prompt = "I have a car and its color is turquoise.";
-  std::vector<int> tokens = WrapAndTokenize(model->Tokenizer(), model->Info(),
-                                            abs_pos, mutable_prompt);
+  std::vector<int> tokens = WrapAndTokenize(model->Tokenizer(),
+                                            model->ChatTemplate(),
+                                            model->Info(), abs_pos,
+                                            mutable_prompt);
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                   timing_info);
   // Note: we do not rewind any <end_of_turn> tokens here. If the model
   // produced one and WrapAndTokenize() inserts another one, it will just be
   // duplicated.
   mutable_prompt = "Please repeat all prior statements.";
-  tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
-                           mutable_prompt);
+  tokens = WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
+                           model->Info(), abs_pos, mutable_prompt);
   // Reset the `response` string here, then check that the model actually has
   // access to the previous turn by asking to reproduce.
   response.clear();
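
The note about duplicated <end_of_turn> tokens follows directly from how the new template wraps continuation turns: for a non-zero position, Apply() prepends "<end_of_turn>\n" instead of BOS. A minimal sketch of that second call, assuming `tokenizer` and `chat_template` stand for `model->Tokenizer()` and `model->ChatTemplate()`:

  std::vector<int> ids;
  HWY_ASSERT(tokenizer.Encode("Please repeat all prior statements.", &ids));
  // abs_pos > 0 here, so Apply() prepends "<end_of_turn>\n" rather than BOS.
  const std::vector<int> wrapped = chat_template.Apply(abs_pos, ids);
  // Decoded, `wrapped` reads:
  //   <end_of_turn>\n<start_of_turn>user\nPlease repeat all prior statements.
  //   <end_of_turn>\n<start_of_turn>model\n
  // A model-produced <end_of_turn> immediately before it is simply repeated.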


@@ -75,7 +75,8 @@ int main(int argc, char** argv) {
   // Tokenize instructions.
   std::string prompt = "Write a greeting to the world.";
   const std::vector<int> tokens = gcpp::WrapAndTokenize(
-      model.Tokenizer(), loader.Info(), generated, prompt);
+      model.Tokenizer(), model.ChatTemplate(), loader.Info(),
+      generated, prompt);
   const size_t prompt_size = tokens.size();
   // This callback function gets invoked every time a token is generated


@@ -72,7 +72,8 @@ class SimplifiedGemma {
   size_t generated = 0;
   const std::vector<int> tokens = gcpp::WrapAndTokenize(
-      model_.Tokenizer(), loader_.Info(), generated, prompt);
+      model_.Tokenizer(), model_.ChatTemplate(), loader_.Info(),
+      generated, prompt);
   const size_t prompt_size = tokens.size();
   // This callback function gets invoked every time a token is generated


@@ -148,18 +148,6 @@ const char* ParseType(const std::string& type_string, Type& type) {
   return kErrorMessageBuffer.c_str();
 }
-void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
-  // Instruction-tuned models are trained to expect control tokens.
-  if (info.wrapping == PromptWrapping::GEMMA_IT) {
-    // Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
-    const std::string start = (pos == 0)
-                                  ? "<start_of_turn>user\n"
-                                  : "<end_of_turn>\n<start_of_turn>user\n";
-    prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
-  }
-}
 float EmbeddingScaling(size_t model_dim) {
   // Round to bf16 to match Gemma's Embedder, which casts before mul.
   return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
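
For comparison, the removed Wrap() built the turn markup as a string before tokenizing; the new GemmaChatTemplate (see the tokenizer.cc hunk below) encodes the control tokens once and splices them around the already-encoded prompt. A sketch of the equivalence for a first GEMMA_IT turn, with a placeholder prompt:

  // Old path: wrap the string, then Encode() everything at once.
  std::string prompt = "Hello";
  prompt = "<start_of_turn>user\n" + prompt +
           "<end_of_turn>\n<start_of_turn>model\n";
  // tokens = Encode(prompt), with BOS_ID inserted at the front when pos == 0.
  //
  // New path: Encode("Hello") alone, then splice pre-encoded pieces:
  //   BOS, sot_user_, <prompt tokens>, eot_, sot_model_
  // which decodes to the same text as the old wrapped string.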


@@ -43,9 +43,6 @@ const char* ParseType(const std::string& type_string, Type& type);
 const char* ModelString(Model model, PromptWrapping wrapping);
 const char* StringFromType(Type type);
-// Wraps the given prompt using the expected control tokens for IT models.
-void Wrap(const ModelInfo& info, size_t pos, std::string& prompt);
 // Returns the scale value to use for the embedding (basically sqrt model_dim).
 float EmbeddingScaling(size_t model_dim);


@@ -40,7 +40,7 @@ namespace gcpp {
 Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
              const ModelInfo& info, MatMulEnv& env)
-    : env_(env), tokenizer_(tokenizer_path) {
+    : env_(env), tokenizer_(tokenizer_path), chat_template_(tokenizer_) {
   model_.Load(weights, info.model, info.weight, info.wrapping,
               env_.parallel.Pools().Pool(0),
               /*tokenizer_proto=*/nullptr);
@@ -51,10 +51,11 @@ Gemma::Gemma(const Path& weights, MatMulEnv& env) : env_(env) {
   model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT,
               env_.parallel.Pools().Pool(0), &tokenizer_proto);
   tokenizer_.Deserialize(tokenizer_proto);
+  chat_template_.Init(tokenizer_);
 }
 Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env)
-    : env_(env), tokenizer_(std::move(tokenizer)) {
+    : env_(env), tokenizer_(std::move(tokenizer)), chat_template_(tokenizer_) {
   HWY_ASSERT(info.weight == Type::kF32);
   model_.Allocate(info.model, info.weight, env_.parallel.Pools().Pool(0));
 }
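
The constructors now cover two initialization paths for the template. A sketch of both, under the assumption that `tokenizer_path` and `tokenizer_proto` are available as in the constructors above:

  // Path 1 (file-based constructors): the tokenizer is usable immediately,
  // so the template is built in the member-initializer list.
  GemmaTokenizer tok_from_file(tokenizer_path);
  GemmaChatTemplate tpl(tok_from_file);   // constructor runs Init().

  // Path 2 (single-file weights): the tokenizer is deserialized later,
  // so the template stays empty until Init() is called explicitly.
  GemmaTokenizer tok_from_blob;
  GemmaChatTemplate deferred;             // Apply() would assert at this point.
  tok_from_blob.Deserialize(tokenizer_proto);
  deferred.Init(tok_from_blob);           // now ready for Apply().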


@@ -213,6 +213,7 @@ class Gemma {
                      .weight = model_.Config().weight});
   }
   const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
+  const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
   const ModelWeightsStorage& Weights() const { return model_; }
   ModelWeightsStorage& MutableWeights() { return model_; }
   void Save(const Path& weights, hwy::ThreadPool& pool) {
@@ -256,6 +257,7 @@ class Gemma {
   MatMulEnv& env_;
   GemmaTokenizer tokenizer_;
+  GemmaChatTemplate chat_template_;
   // Type-erased so that this can be defined in the header.
   ModelWeightsStorage model_;
 };
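
One detail of the new member: `chat_template_` is declared after `tokenizer_`, and members are initialized in declaration order, so the `chat_template_(tokenizer_)` initializers in gemma.cc always see a fully constructed tokenizer. A reduced sketch of the pattern (class name and constructor are hypothetical):

  class OrderSketch {
   public:
    explicit OrderSketch(const Path& tokenizer_path)
        : tokenizer_(tokenizer_path),    // initialized first: declared first
          chat_template_(tokenizer_) {}  // safe: tokenizer_ already constructed
   private:
    GemmaTokenizer tokenizer_;           // declaration order drives init order,
    GemmaChatTemplate chat_template_;    // not the order in the initializer list
  };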


@@ -162,16 +162,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
       continue;
     }
-    // Wrap, tokenize and maybe log prompt tokens.
-    std::vector<int> prompt = WrapAndTokenize(
-        model.Tokenizer(), model.Info(), abs_pos, prompt_string);
-    prompt_size = prompt.size();
-    if constexpr (kVerboseLogTokens) {
-      for (int i = 0; i < prompt_size; ++i) {
-        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
-      }
-    }
+    std::vector<int> prompt;
     // Set up runtime config.
     TimingInfo timing_info = {.verbosity = app.verbosity};
     RuntimeConfig runtime_config = {.gen = &gen,
@@ -182,22 +173,26 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
     args.CopyTo(runtime_config);
     size_t prefix_end = 0;
     if (have_image) {
+      prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
+                               model.Info(), abs_pos, prompt_string,
+                               image_tokens.BatchSize());
       runtime_config.image_tokens = &image_tokens;
-      if (model.Info().wrapping == PromptWrapping::PALIGEMMA) {
-        prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
-      } else if (model.Info().wrapping == PromptWrapping::GEMMA_VLM) {
-        size_t seq_len = model.GetModelConfig().vit_config.seq_len;
-        size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
-        prompt =
-            WrapVLM(model.Tokenizer(), model.Info(), abs_pos, prompt,
-                    image_tokens.BatchSize(), seq_len / (pool_dim * pool_dim));
-      }
       prompt_size = prompt.size();
       // The end of the prefix for prefix-LM style attention in Paligemma.
       // See Figure 2 of https://arxiv.org/abs/2407.07726.
       prefix_end = prompt_size;
       // We need to look at all the tokens for the prefix.
       runtime_config.prefill_tbatch_size = prompt_size;
+    } else {
+      prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
+                               model.Info(), abs_pos, prompt_string);
+      prompt_size = prompt.size();
+    }
+    if constexpr (kVerboseLogTokens) {
+      for (int i = 0; i < prompt_size; ++i) {
+        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
+      }
     }
     // Generate until EOS or max_generated_tokens.
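
Both branches now leave the prompt fully wrapped before the prefix bookkeeping, so `prefix_end` and `prefill_tbatch_size` cover the image placeholders as well as the text. A sketch of the layouts the image overload returns (see the tokenizer.cc hunk below); the helper name is hypothetical:

  // PALIGEMMA (pos must be 0):  [0] * batch, BOS, <text tokens>, "\n"
  // GEMMA_VLM (via Apply()):    BOS, <start_of_turn>user\n, <text tokens>,
  //                             "\n\n<start_of_image>", [-2] * batch,
  //                             "<end_of_image>\n\n", <end_of_turn>\n,
  //                             <start_of_turn>model\n
  size_t PrefixEndFor(const std::vector<int>& wrapped_prompt) {
    // Prefix-LM attention treats the entire wrapped prompt as the prefix.
    return wrapped_prompt.size();
  }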


@@ -114,57 +114,96 @@ bool GemmaTokenizer::Decode(const std::vector<int>& ids,
   return impl_->Decode(ids, detokenized);
 }
-std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
-                                 const ModelInfo& info, size_t pos,
-                                 std::string& prompt) {
-  Wrap(info, pos, prompt);
-  std::vector<int> tokens;
-  HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
-  // Both pre-trained and instruction-tuned require BOS as first token.
-  if (pos == 0) {
-    tokens.insert(tokens.begin(), BOS_ID);
-  }
-  // PaliGemma separator. The SEP token "\n" is always tokenized separately.
-  if (info.wrapping == PromptWrapping::PALIGEMMA
-      // || info.wrapping == PromptWrapping::GEMMA_VLM
-  ) {
-    std::vector<int> sep_tokens;
-    HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
-    tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
-  }
-  return tokens;
-}
-std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
-                         size_t pos, std::vector<int>& tokens,
-                         size_t image_batch_size, size_t max_image_batch_size) {
-  HWY_ASSERT(info.wrapping == PromptWrapping::GEMMA_VLM);
-  size_t num_images = hwy::DivCeil(image_batch_size, max_image_batch_size);
-  std::vector<int> sep_tokens;
-  HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
-  std::string begin_image_prompt = "\n\n<start_of_image>";
-  std::vector<int> begin_image_tokens =
-      WrapAndTokenize(tokenizer, info, pos, begin_image_prompt);
-  std::string end_image_prompt = "<end_of_image>\n\n";
-  std::vector<int> end_image_tokens =
-      WrapAndTokenize(tokenizer, info, pos, end_image_prompt);
-  for (size_t i = 0; i < num_images; ++i) {
-    tokens.insert(tokens.begin(), begin_image_tokens.begin(),
-                  begin_image_tokens.end());
-    tokens.insert(tokens.begin() + begin_image_tokens.size(), image_batch_size,
-                  -2);
-    tokens.insert(tokens.begin() + begin_image_tokens.size() + image_batch_size,
-                  end_image_tokens.begin(), end_image_tokens.end());
-  }
-  return tokens;
-}
+GemmaChatTemplate::GemmaChatTemplate(const GemmaTokenizer& tokenizer) {
+  Init(tokenizer);
+}
+void GemmaChatTemplate::Init(const GemmaTokenizer& tokenizer) {
+  sot_user_.reserve(3);
+  HWY_ASSERT(tokenizer.Encode("<start_of_turn>user\n", &sot_user_));
+  sot_model_.reserve(3);
+  HWY_ASSERT(tokenizer.Encode("<start_of_turn>model\n", &sot_model_));
+  eot_.reserve(2);
+  HWY_ASSERT(tokenizer.Encode("<end_of_turn>\n", &eot_));
+}
+std::vector<int> GemmaChatTemplate::Apply(size_t pos,
+                                          const std::vector<int>& ids) const {
+  HWY_ASSERT_M(!sot_user_.empty() && !sot_model_.empty() && !eot_.empty(),
+               "GemmaChatTemplate has not been initialized.");
+  std::vector<int> out;
+  out.reserve(eot_.size() +
+              sot_user_.size() +
+              ids.size() +
+              eot_.size() +
+              sot_model_.size());
+  if (pos > 0) {
+    out.insert(out.cend(), eot_.cbegin(), eot_.cend());
+  } else {
+    out.push_back(BOS_ID);
+  }
+  out.insert(out.cend(), sot_user_.cbegin(), sot_user_.cend());
+  out.insert(out.cend(), ids.cbegin(), ids.cend());
+  out.insert(out.cend(), eot_.cbegin(), eot_.cend());
+  out.insert(out.cend(), sot_model_.cbegin(), sot_model_.cend());
+  return out;
+}
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt) {
+  std::vector<int> tokens;
+  HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
+  switch (info.wrapping) {
+    case PromptWrapping::GEMMA_IT:
+    case PromptWrapping::GEMMA_VLM:
+      return chat_template.Apply(pos, tokens);
+    default:
+      if (pos == 0) {
+        tokens.insert(tokens.cbegin(), BOS_ID);
+      }
+      return tokens;
+  }
+}
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt,
+                                 size_t image_batch_size) {
+  std::vector<int> text_part;
+  HWY_ASSERT(tokenizer.Encode(prompt, &text_part));
+  std::vector<int> tokens;
+  switch (info.wrapping) {
+    case PromptWrapping::PALIGEMMA: {
+      std::vector<int> sep;
+      HWY_ASSERT(tokenizer.Encode("\n", &sep));
+      tokens.reserve(image_batch_size + 1 + text_part.size() + sep.size());
+      tokens.resize(image_batch_size, 0);
+      HWY_ASSERT(pos == 0);
+      tokens.push_back(BOS_ID);
+      tokens.insert(tokens.cend(), text_part.cbegin(), text_part.cend());
+      tokens.insert(tokens.cend(), sep.cbegin(), sep.cend());
+      return tokens;
+    }
+    case PromptWrapping::GEMMA_VLM: {
+      std::vector<int> soi;
+      soi.reserve(2);
+      HWY_ASSERT(tokenizer.Encode("\n\n<start_of_image>", &soi));
+      std::vector<int> eoi;
+      eoi.reserve(2);
+      HWY_ASSERT(tokenizer.Encode("<end_of_image>\n\n", &eoi));
+      tokens.reserve(text_part.size() + soi.size() + image_batch_size + eoi.size());
+      tokens.insert(tokens.cend(), text_part.cbegin(), text_part.cend());
+      tokens.insert(tokens.cend(), soi.cbegin(), soi.cend());
+      tokens.insert(tokens.cend(), image_batch_size, -2);
+      tokens.insert(tokens.cend(), eoi.cbegin(), eoi.cend());
+      return chat_template.Apply(pos, tokens);
+    }
+    default:
+      HWY_ASSERT_M(false, "Current variant does not support vision prompt.");
+  }
+}
 } // namespace gcpp
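
Taken together, Apply() produces the canonical instruction-tuned turn layout. A short round-trip sketch, assuming a loaded `tokenizer` and the includes of this file:

  void SketchFirstTurn(const gcpp::GemmaTokenizer& tokenizer) {
    gcpp::GemmaChatTemplate tpl(tokenizer);
    std::vector<int> ids;
    HWY_ASSERT(tokenizer.Encode("Write a haiku.", &ids));
    const std::vector<int> wrapped = tpl.Apply(/*pos=*/0, ids);
    std::string text;
    HWY_ASSERT(tokenizer.Decode(wrapped, &text));
    // Apart from the leading BOS id, `wrapped` decodes to:
    //   <start_of_turn>user\nWrite a haiku.<end_of_turn>\n<start_of_turn>model\n
  }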


@@ -54,13 +54,30 @@ class GemmaTokenizer {
   std::unique_ptr<Impl> impl_;
 };
-std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
-                                 const ModelInfo& info, size_t pos,
-                                 std::string& prompt);
-std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
-                         size_t pos, std::vector<int>& tokens,
-                         size_t image_batch_size, size_t max_image_batch_size);
+class GemmaChatTemplate {
+ public:
+  GemmaChatTemplate() = default;
+  explicit GemmaChatTemplate(const GemmaTokenizer& tokenizer);
+  void Init(const GemmaTokenizer& tokenizer);
+  std::vector<int> Apply(size_t pos, const std::vector<int>& ids) const;
+ private:
+  std::vector<int> sot_user_;
+  std::vector<int> sot_model_;
+  std::vector<int> eot_;
+};
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt);
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt,
+                                 size_t image_batch_size);
 } // namespace gcpp
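
A caller-side usage sketch of the two overloads, assuming a loaded `Gemma model` plus the `abs_pos`, `prompt_string`, and `image_tokens` variables from run.cc:

  // Text-only prompt: chat template or bare BOS, chosen by info.wrapping.
  std::vector<int> text_prompt = gcpp::WrapAndTokenize(
      model.Tokenizer(), model.ChatTemplate(), model.Info(), abs_pos,
      prompt_string);

  // Image + text prompt: the extra argument reserves one slot per image token.
  std::vector<int> vision_prompt = gcpp::WrapAndTokenize(
      model.Tokenizer(), model.ChatTemplate(), model.Info(), abs_pos,
      prompt_string, image_tokens.BatchSize());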