Merge pull request #529 from ufownl:refactor/wrap_and_tokenize

PiperOrigin-RevId: 745174371
Copybara-Service 2025-04-08 09:22:26 -07:00
commit bef91a3f03
11 changed files with 176 additions and 83 deletions

View File

@@ -69,7 +69,8 @@ class GemmaEnv {
   }

   std::vector<int> WrapAndTokenize(std::string& input) const {
-    return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->Info(), 0, input);
+    return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->ChatTemplate(),
+                                 model_->Info(), 0, input);
   }

   std::string StringFromTokens(const std::vector<int>& tokens) const {

View File

@@ -178,22 +178,25 @@ TEST_F(GemmaTest, Multiturn) {
   TimingInfo timing_info{.verbosity = 0};
   // First "say" something slightly unusual.
   std::string mutable_prompt = "I have a car and its color is turquoise.";
-  std::vector<int> tokens = WrapAndTokenize(model->Tokenizer(), model->Info(),
-                                            abs_pos, mutable_prompt);
+  std::vector<int> tokens =
+      WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(), model->Info(),
+                      abs_pos, mutable_prompt);
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                   timing_info);
   // Note: we do not rewind any <end_of_turn> tokens here. If the model
   // produced one and WrapAndTokenize() inserts another one, it will just be
   // duplicated.
   mutable_prompt = "Please repeat all prior statements.";
-  tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
-                           mutable_prompt);
+  tokens = WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
+                           model->Info(), abs_pos, mutable_prompt);
   // Reset the `response` string here, then check that the model actually has
   // access to the previous turn by asking to reproduce.
   response.clear();
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                   timing_info);
-  fprintf(stderr, "decoded: %s\n", response.c_str());
+  fprintf(stderr, "decoded: '%s'\n", response.c_str());
   bool remembered_turquoise =
       response.find("turquoise") != std::string::npos;              // NOLINT
   bool remembered_car = response.find("car") != std::string::npos;  // NOLINT

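For reference, a sketch of the turn structure that the two wrapped prompts in this test correspond to, based on the GemmaChatTemplate control tokens introduced later in this change (string form only; the actual token ids depend on the tokenizer):

  // First call (abs_pos == 0): BOS, then the user turn and the model cue:
  //   <bos><start_of_turn>user\n
  //   I have a car and its color is turquoise.<end_of_turn>\n
  //   <start_of_turn>model\n
  // Second call (abs_pos > 0): <end_of_turn>\n replaces BOS, closing the
  // model's previous turn before the next user turn is appended.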
View File

@@ -74,8 +74,9 @@ int main(int argc, char** argv) {
   // Tokenize instructions.
   std::string prompt = "Write a greeting to the world.";
-  const std::vector<int> tokens = gcpp::WrapAndTokenize(
-      model.Tokenizer(), loader.Info(), generated, prompt);
+  const std::vector<int> tokens =
+      gcpp::WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
+                            loader.Info(), generated, prompt);
   const size_t prompt_size = tokens.size();

   // This callback function gets invoked every time a token is generated

View File

@@ -72,7 +72,8 @@ class SimplifiedGemma {
     size_t generated = 0;
     const std::vector<int> tokens = gcpp::WrapAndTokenize(
-        model_.Tokenizer(), loader_.Info(), generated, prompt);
+        model_.Tokenizer(), model_.ChatTemplate(), loader_.Info(),
+        generated, prompt);
     const size_t prompt_size = tokens.size();

     // This callback function gets invoked every time a token is generated

View File

@@ -25,6 +25,8 @@
 #include <vector>

 #include "util/basics.h"  // BF16
+// TODO: change include when PromptWrapping is moved.
+#include "compression/shared.h"  // PromptWrapping
 #include "hwy/base.h"

 namespace gcpp {
@@ -79,7 +81,7 @@ constexpr PromptWrapping kPromptWrapping[] = {
     PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA,  // PG2 3B 224/448
     PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA,  // PG2 10B 224/448
     PromptWrapping::GEMMA_VLM,                             // Gemma3 4B
-    PromptWrapping::GEMMA_PT,                               // Gemma3 1B
+    PromptWrapping::GEMMA_IT,                               // Gemma3 1B
     PromptWrapping::GEMMA_VLM,                             // Gemma3 12B
     PromptWrapping::GEMMA_VLM,                             // Gemma3 27B
 };

View File

@@ -44,6 +44,7 @@ const char* ModelString(Model model, PromptWrapping wrapping);
 const char* StringFromType(Type type);

 // Wraps the given prompt using the expected control tokens for IT models.
+// `GemmaChatTemplate` is preferred if a tokenized return value is fine.
 void Wrap(const ModelInfo& info, size_t pos, std::string& prompt);

 // Returns the scale value to use for the embedding (basically sqrt model_dim).

View File

@@ -44,6 +44,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
   model_.Load(weights, info.model, info.weight, info.wrapping,
               env_.parallel.Pools().Pool(0),
               /*tokenizer_proto=*/nullptr);
+  chat_template_.Init(tokenizer_, model_.Config().model);
 }

 Gemma::Gemma(const Path& weights, MatMulEnv& env) : env_(env) {
@@ -51,10 +52,13 @@ Gemma::Gemma(const Path& weights, MatMulEnv& env) : env_(env) {
   model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT,
               env_.parallel.Pools().Pool(0), &tokenizer_proto);
   tokenizer_.Deserialize(tokenizer_proto);
+  chat_template_.Init(tokenizer_, model_.Config().model);
 }

 Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env)
-    : env_(env), tokenizer_(std::move(tokenizer)) {
+    : env_(env),
+      tokenizer_(std::move(tokenizer)),
+      chat_template_(tokenizer_, info.model) {
   HWY_ASSERT(info.weight == Type::kF32);
   model_.Allocate(info.model, info.weight, env_.parallel.Pools().Pool(0));
 }

View File

@@ -213,6 +213,7 @@ class Gemma {
                       .weight = model_.Config().weight});
   }
   const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
+  const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
   const ModelWeightsStorage& Weights() const { return model_; }
   ModelWeightsStorage& MutableWeights() { return model_; }
   void Save(const Path& weights, hwy::ThreadPool& pool) {
@@ -256,6 +257,7 @@ class Gemma {
   MatMulEnv& env_;
   GemmaTokenizer tokenizer_;
+  GemmaChatTemplate chat_template_;
   // Type-erased so that this can be defined in the header.
   ModelWeightsStorage model_;
 };

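For context, a minimal sketch of a call site that combines the new ChatTemplate() accessor with the refactored free function. This is illustrative only; `model` (a loaded gcpp::Gemma) and `abs_pos` are assumed from the surrounding examples rather than defined here.

  // Sketch only: `model` is a loaded gcpp::Gemma; `abs_pos` is the current
  // KV-cache position (0 for a fresh conversation).
  std::string prompt = "Write a greeting to the world.";
  const std::vector<int> tokens = gcpp::WrapAndTokenize(
      model.Tokenizer(), model.ChatTemplate(), model.Info(), abs_pos, prompt);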
View File

@@ -162,16 +162,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
       continue;
     }

-    // Wrap, tokenize and maybe log prompt tokens.
-    std::vector<int> prompt = WrapAndTokenize(
-        model.Tokenizer(), model.Info(), abs_pos, prompt_string);
-    prompt_size = prompt.size();
-    if constexpr (kVerboseLogTokens) {
-      for (int i = 0; i < prompt_size; ++i) {
-        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
-      }
-    }
-
     // Set up runtime config.
     TimingInfo timing_info = {.verbosity = app.verbosity};
     RuntimeConfig runtime_config = {.gen = &gen,
@@ -181,23 +171,29 @@
                                     .use_spinning = app.spin};
     args.CopyTo(runtime_config);
     size_t prefix_end = 0;
+    std::vector<int> prompt;
     if (have_image) {
+      prompt =
+          WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
+                          abs_pos, prompt_string, image_tokens.BatchSize());
       runtime_config.image_tokens = &image_tokens;
-      if (model.Info().wrapping == PromptWrapping::PALIGEMMA) {
-        prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
-      } else if (model.Info().wrapping == PromptWrapping::GEMMA_VLM) {
-        size_t seq_len = model.GetModelConfig().vit_config.seq_len;
-        size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
-        prompt =
-            WrapVLM(model.Tokenizer(), model.Info(), abs_pos, prompt,
-                    image_tokens.BatchSize(), seq_len / (pool_dim * pool_dim));
-      }
       prompt_size = prompt.size();
       // The end of the prefix for prefix-LM style attention in Paligemma.
       // See Figure 2 of https://arxiv.org/abs/2407.07726.
       prefix_end = prompt_size;
       // We need to look at all the tokens for the prefix.
       runtime_config.prefill_tbatch_size = prompt_size;
+    } else {
+      prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
+                               model.Info(), abs_pos, prompt_string);
+      prompt_size = prompt.size();
+    }
+    if constexpr (kVerboseLogTokens) {
+      for (int i = 0; i < prompt_size; ++i) {
+        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
+      }
     }

     // Generate until EOS or max_generated_tokens.

View File

@@ -100,71 +100,123 @@ void GemmaTokenizer::Deserialize(const std::string& tokenizer_proto) {

 bool GemmaTokenizer::Encode(const std::string& input,
                             std::vector<std::string>* pieces) const {
-  return impl_->Encode(input, pieces);
+  return impl_ && impl_->Encode(input, pieces);
 }

 bool GemmaTokenizer::Encode(const std::string& input,
                             std::vector<int>* ids) const {
-  return impl_->Encode(input, ids);
+  return impl_ && impl_->Encode(input, ids);
 }

 // Given a sequence of ids, decodes it into a detokenized output.
 bool GemmaTokenizer::Decode(const std::vector<int>& ids,
                             std::string* detokenized) const {
-  return impl_->Decode(ids, detokenized);
+  return impl_ && impl_->Decode(ids, detokenized);
 }

-std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
-                                 const ModelInfo& info, size_t pos,
-                                 std::string& prompt) {
-  Wrap(info, pos, prompt);
+bool GemmaChatTemplate::Init(const GemmaTokenizer& tokenizer, Model model) {
+  sot_user_.reserve(3);
+  if (!tokenizer.Encode("<start_of_turn>user\n", &sot_user_)) return false;
+  sot_model_.reserve(3);
+  HWY_ASSERT(tokenizer.Encode("<start_of_turn>model\n", &sot_model_));
+  eot_.reserve(2);
+  HWY_ASSERT(tokenizer.Encode("<end_of_turn>\n", &eot_));
+  HWY_ASSERT(tokenizer.Encode("\n", &pali_sep_));
+  vlm_soi_.reserve(2);
+  HWY_ASSERT(tokenizer.Encode("\n\n<start_of_image>", &vlm_soi_));
+  vlm_eoi_.reserve(2);
+  HWY_ASSERT(tokenizer.Encode("<end_of_image>\n\n", &vlm_eoi_));
+  return true;
+}
+
+std::vector<int> GemmaChatTemplate::Apply(size_t pos,
+                                          const std::vector<int>& ids) const {
+  HWY_ASSERT_M(!sot_user_.empty() && !sot_model_.empty() && !eot_.empty(),
+               "GemmaChatTemplate has not been initialized.");
+  std::vector<int> out;
+  out.reserve(eot_.size() + sot_user_.size() + ids.size() + eot_.size() +
+              sot_model_.size());
+  // Start with BOS, or prepend end_of_turn if this is a continuation.
+  if (pos == 0) {
+    out.push_back(BOS_ID);
+  } else {
+    out.insert(out.cend(), eot_.cbegin(), eot_.cend());
+  }
+  // Start of user turn, user prompt, end of turn; then start of model turn.
+  out.insert(out.cend(), sot_user_.cbegin(), sot_user_.cend());
+  out.insert(out.cend(), ids.cbegin(), ids.cend());
+  out.insert(out.cend(), eot_.cbegin(), eot_.cend());
+  out.insert(out.cend(), sot_model_.cbegin(), sot_model_.cend());
+  return out;
+}
+
+std::vector<int> GemmaChatTemplate::WrapPali(const std::vector<int>& text_part,
+                                             size_t image_batch_size) const {
+  HWY_ASSERT_M(!pali_sep_.empty(),
+               "GemmaChatTemplate has not been initialized.");
+  std::vector<int> out;
+  out.reserve(image_batch_size + 1 + text_part.size() + pali_sep_.size());
+  out.resize(image_batch_size, 0);
+  out.push_back(BOS_ID);
+  out.insert(out.cend(), text_part.cbegin(), text_part.cend());
+  out.insert(out.cend(), pali_sep_.cbegin(), pali_sep_.cend());
+  return out;
+}
+
+std::vector<int> GemmaChatTemplate::WrapVLM(const std::vector<int>& text_part,
+                                            size_t image_batch_size) const {
+  HWY_ASSERT_M(!vlm_soi_.empty() && !vlm_eoi_.empty(),
+               "GemmaChatTemplate has not been initialized.");
+  std::vector<int> out;
+  out.reserve(text_part.size() + vlm_soi_.size() + image_batch_size +
+              vlm_eoi_.size());
+  out.insert(out.cend(), text_part.cbegin(), text_part.cend());
+  out.insert(out.cend(), vlm_soi_.cbegin(), vlm_soi_.cend());
+  out.insert(out.cend(), image_batch_size, -2);
+  out.insert(out.cend(), vlm_eoi_.cbegin(), vlm_eoi_.cend());
+  return out;
+}
+
+// Text
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt) {
   std::vector<int> tokens;
   HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
-  // Both pre-trained and instruction-tuned require BOS as first token.
-  if (pos == 0) {
-    tokens.insert(tokens.begin(), BOS_ID);
-  }
-  // PaliGemma separator. The SEP token "\n" is always tokenized separately.
-  if (info.wrapping == PromptWrapping::PALIGEMMA
-      // || info.wrapping == PromptWrapping::GEMMA_VLM
-  ) {
-    std::vector<int> sep_tokens;
-    HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
-    tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
+  switch (info.wrapping) {
+    case PromptWrapping::GEMMA_IT:
+    case PromptWrapping::GEMMA_VLM:
+      return chat_template.Apply(pos, tokens);
+    default:
+      if (pos == 0) {
+        tokens.insert(tokens.cbegin(), BOS_ID);
+      }
+      return tokens;
   }
-  return tokens;
 }

-std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
-                         size_t pos, std::vector<int>& tokens,
-                         size_t image_batch_size, size_t max_image_batch_size) {
-  HWY_ASSERT(info.wrapping == PromptWrapping::GEMMA_VLM);
-  size_t num_images = hwy::DivCeil(image_batch_size, max_image_batch_size);
-
-  std::vector<int> sep_tokens;
-  HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
-
-  std::string begin_image_prompt = "\n\n<start_of_image>";
-  std::vector<int> begin_image_tokens =
-      WrapAndTokenize(tokenizer, info, pos, begin_image_prompt);
-
-  std::string end_image_prompt = "<end_of_image>\n\n";
-  std::vector<int> end_image_tokens =
-      WrapAndTokenize(tokenizer, info, pos, end_image_prompt);
-
-  for (size_t i = 0; i < num_images; ++i) {
-    tokens.insert(tokens.begin(), begin_image_tokens.begin(),
-                  begin_image_tokens.end());
-    tokens.insert(tokens.begin() + begin_image_tokens.size(), image_batch_size,
-                  -2);
-    tokens.insert(tokens.begin() + begin_image_tokens.size() + image_batch_size,
-                  end_image_tokens.begin(), end_image_tokens.end());
+// Vision
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt,
+                                 size_t image_batch_size) {
+  std::vector<int> text_part;
+  HWY_ASSERT(tokenizer.Encode(prompt, &text_part));
+  switch (info.wrapping) {
+    case PromptWrapping::PALIGEMMA:
+      HWY_ASSERT(pos == 0);
+      return chat_template.WrapPali(text_part, image_batch_size);
+    case PromptWrapping::GEMMA_VLM:
+      return chat_template.Apply(
+          pos, chat_template.WrapVLM(text_part, image_batch_size));
+    default:
+      HWY_ASSERT_M(false, "Current variant does not support vision prompt.");
   }
-  return tokens;
 }

 }  // namespace gcpp

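A short sketch of what the vision overload above produces. The placeholder values (0 for PaliGemma, -2 for the Gemma VLM path) come from WrapPali/WrapVLM; the assumption that these slots are later replaced by image embeddings during prefill is not part of this diff. `model`, `prompt_string` and `image_tokens` are assumed from run.cc above.

  // Sketch (assumed usage, mirroring run.cc). Layout of the wrapped prompt:
  //   PALIGEMMA:  [0 x image_batch_size][BOS][prompt tokens]["\n"]
  //   GEMMA_VLM:  Apply(pos, [prompt tokens]["\n\n<start_of_image>"]
  //                          [-2 x image_batch_size]["<end_of_image>\n\n"])
  std::vector<int> prompt_tokens = gcpp::WrapAndTokenize(
      model.Tokenizer(), model.ChatTemplate(), model.Info(), /*pos=*/0,
      prompt_string, image_tokens.BatchSize());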
View File

@@ -54,13 +54,43 @@ class GemmaTokenizer {
   std::unique_ptr<Impl> impl_;
 };

-std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
-                                 const ModelInfo& info, size_t pos,
-                                 std::string& prompt);
+class GemmaChatTemplate {
+ public:
+  GemmaChatTemplate() = default;
+  explicit GemmaChatTemplate(const GemmaTokenizer& tokenizer, Model model) {
+    (void)Init(tokenizer, model);
+  }
+  // Returns false if the tokenizer is not available (as in optimize_test.cc).
+  bool Init(const GemmaTokenizer& tokenizer, Model model);

-std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
-                         size_t pos, std::vector<int>& tokens,
-                         size_t image_batch_size, size_t max_image_batch_size);
+  // Given prompt tokens, this returns the wrapped prompt including BOS and
+  // any "start_of_turn" structure required by the model.
+  std::vector<int> Apply(size_t pos, const std::vector<int>& ids) const;
+  std::vector<int> WrapPali(const std::vector<int>& text_part,
+                            size_t image_batch_size) const;
+  std::vector<int> WrapVLM(const std::vector<int>& text_part,
+                           size_t image_batch_size) const;
+
+ private:
+  std::vector<int> sot_user_;
+  std::vector<int> sot_model_;
+  std::vector<int> eot_;
+  std::vector<int> pali_sep_;
+  std::vector<int> vlm_soi_;
+  std::vector<int> vlm_eoi_;
+};
+
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt);
+
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt,
+                                 size_t image_batch_size);

 }  // namespace gcpp
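Finally, a minimal sketch of using GemmaChatTemplate on its own, e.g. when only a tokenizer is at hand. `tokenizer` and `model` (a Model enum value) are assumed to come from an existing ModelInfo or loader, and the prompt text is illustrative; this is not part of the change itself.

  // Sketch only, not part of this change.
  gcpp::GemmaChatTemplate chat_template;
  if (chat_template.Init(tokenizer, model)) {
    std::vector<int> prompt_ids;
    HWY_ASSERT(tokenizer.Encode("Hello!", &prompt_ids));
    // pos == 0: first turn, so Apply() starts the sequence with BOS.
    const std::vector<int> wrapped = chat_template.Apply(/*pos=*/0, prompt_ids);
  }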