Merge pull request #529 from ufownl:refactor/wrap_and_tokenize

PiperOrigin-RevId: 745174371
This commit is contained in:
Copybara-Service 2025-04-08 09:22:26 -07:00
commit bef91a3f03
11 changed files with 176 additions and 83 deletions

View File

@ -69,7 +69,8 @@ class GemmaEnv {
}
std::vector<int> WrapAndTokenize(std::string& input) const {
return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->Info(), 0, input);
return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->ChatTemplate(),
model_->Info(), 0, input);
}
std::string StringFromTokens(const std::vector<int>& tokens) const {

View File

@ -178,22 +178,25 @@ TEST_F(GemmaTest, Multiturn) {
TimingInfo timing_info{.verbosity = 0};
// First "say" something slightly unusual.
std::string mutable_prompt = "I have a car and its color is turquoise.";
std::vector<int> tokens = WrapAndTokenize(model->Tokenizer(), model->Info(),
abs_pos, mutable_prompt);
std::vector<int> tokens =
WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(), model->Info(),
abs_pos, mutable_prompt);
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
timing_info);
// Note: we do not rewind any <end_of_turn> tokens here. If the model
// produced one and WrapAndTokenize() inserts another one, it will just be
// duplicated.
mutable_prompt = "Please repeat all prior statements.";
tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
mutable_prompt);
tokens = WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
model->Info(), abs_pos, mutable_prompt);
// Reset the `response` string here, then check that the model actually has
// access to the previous turn by asking to reproduce.
response.clear();
model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
timing_info);
fprintf(stderr, "decoded: %s\n", response.c_str());
fprintf(stderr, "decoded: '%s'\n", response.c_str());
bool remembered_turquoise =
response.find("turquoise") != std::string::npos; // NOLINT
bool remembered_car = response.find("car") != std::string::npos; // NOLINT

View File

@ -74,8 +74,9 @@ int main(int argc, char** argv) {
// Tokenize instructions.
std::string prompt = "Write a greeting to the world.";
const std::vector<int> tokens = gcpp::WrapAndTokenize(
model.Tokenizer(), loader.Info(), generated, prompt);
const std::vector<int> tokens =
gcpp::WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
loader.Info(), generated, prompt);
const size_t prompt_size = tokens.size();
// This callback function gets invoked every time a token is generated

View File

@ -72,7 +72,8 @@ class SimplifiedGemma {
size_t generated = 0;
const std::vector<int> tokens = gcpp::WrapAndTokenize(
model_.Tokenizer(), loader_.Info(), generated, prompt);
model_.Tokenizer(), model_.ChatTemplate(), loader_.Info(),
generated, prompt);
const size_t prompt_size = tokens.size();
// This callback function gets invoked every time a token is generated
@ -115,4 +116,4 @@ class SimplifiedGemma {
gcpp::KVCache kv_cache_;
std::mt19937 gen_;
std::string validation_error_;
};
};

View File

@ -25,6 +25,8 @@
#include <vector>
#include "util/basics.h" // BF16
// TODO: change include when PromptWrapping is moved.
#include "compression/shared.h" // PromptWrapping
#include "hwy/base.h"
namespace gcpp {
@ -79,7 +81,7 @@ constexpr PromptWrapping kPromptWrapping[] = {
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 3B 224/448
PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA, // PG2 10B 224/448
PromptWrapping::GEMMA_VLM, // Gemma3 4B
PromptWrapping::GEMMA_PT, // Gemma3 1B
PromptWrapping::GEMMA_IT, // Gemma3 1B
PromptWrapping::GEMMA_VLM, // Gemma3 12B
PromptWrapping::GEMMA_VLM, // Gemma3 27B
};

View File

@ -44,6 +44,7 @@ const char* ModelString(Model model, PromptWrapping wrapping);
const char* StringFromType(Type type);
// Wraps the given prompt using the expected control tokens for IT models.
// `GemmaChatTemplate` is preferred if a tokenized return value is fine.
void Wrap(const ModelInfo& info, size_t pos, std::string& prompt);
// Returns the scale value to use for the embedding (basically sqrt model_dim).

View File

@ -44,6 +44,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
model_.Load(weights, info.model, info.weight, info.wrapping,
env_.parallel.Pools().Pool(0),
/*tokenizer_proto=*/nullptr);
chat_template_.Init(tokenizer_, model_.Config().model);
}
Gemma::Gemma(const Path& weights, MatMulEnv& env) : env_(env) {
@ -51,10 +52,13 @@ Gemma::Gemma(const Path& weights, MatMulEnv& env) : env_(env) {
model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT,
env_.parallel.Pools().Pool(0), &tokenizer_proto);
tokenizer_.Deserialize(tokenizer_proto);
chat_template_.Init(tokenizer_, model_.Config().model);
}
Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env)
: env_(env), tokenizer_(std::move(tokenizer)) {
: env_(env),
tokenizer_(std::move(tokenizer)),
chat_template_(tokenizer_, info.model) {
HWY_ASSERT(info.weight == Type::kF32);
model_.Allocate(info.model, info.weight, env_.parallel.Pools().Pool(0));
}

View File

@ -213,6 +213,7 @@ class Gemma {
.weight = model_.Config().weight});
}
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
const ModelWeightsStorage& Weights() const { return model_; }
ModelWeightsStorage& MutableWeights() { return model_; }
void Save(const Path& weights, hwy::ThreadPool& pool) {
@ -256,6 +257,7 @@ class Gemma {
MatMulEnv& env_;
GemmaTokenizer tokenizer_;
GemmaChatTemplate chat_template_;
// Type-erased so that this can be defined in the header.
ModelWeightsStorage model_;
};

View File

@ -162,16 +162,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
continue;
}
// Wrap, tokenize and maybe log prompt tokens.
std::vector<int> prompt = WrapAndTokenize(
model.Tokenizer(), model.Info(), abs_pos, prompt_string);
prompt_size = prompt.size();
if constexpr (kVerboseLogTokens) {
for (int i = 0; i < prompt_size; ++i) {
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
}
}
// Set up runtime config.
TimingInfo timing_info = {.verbosity = app.verbosity};
RuntimeConfig runtime_config = {.gen = &gen,
@ -181,23 +171,29 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
.use_spinning = app.spin};
args.CopyTo(runtime_config);
size_t prefix_end = 0;
std::vector<int> prompt;
if (have_image) {
prompt =
WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
abs_pos, prompt_string, image_tokens.BatchSize());
runtime_config.image_tokens = &image_tokens;
if (model.Info().wrapping == PromptWrapping::PALIGEMMA) {
prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
} else if (model.Info().wrapping == PromptWrapping::GEMMA_VLM) {
size_t seq_len = model.GetModelConfig().vit_config.seq_len;
size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
prompt =
WrapVLM(model.Tokenizer(), model.Info(), abs_pos, prompt,
image_tokens.BatchSize(), seq_len / (pool_dim * pool_dim));
}
prompt_size = prompt.size();
// The end of the prefix for prefix-LM style attention in Paligemma.
// See Figure 2 of https://arxiv.org/abs/2407.07726.
prefix_end = prompt_size;
// We need to look at all the tokens for the prefix.
runtime_config.prefill_tbatch_size = prompt_size;
} else {
prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
model.Info(), abs_pos, prompt_string);
prompt_size = prompt.size();
}
if constexpr (kVerboseLogTokens) {
for (int i = 0; i < prompt_size; ++i) {
fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
}
}
// Generate until EOS or max_generated_tokens.

View File

@ -100,71 +100,123 @@ void GemmaTokenizer::Deserialize(const std::string& tokenizer_proto) {
bool GemmaTokenizer::Encode(const std::string& input,
std::vector<std::string>* pieces) const {
return impl_->Encode(input, pieces);
return impl_ && impl_->Encode(input, pieces);
}
bool GemmaTokenizer::Encode(const std::string& input,
std::vector<int>* ids) const {
return impl_->Encode(input, ids);
return impl_ && impl_->Encode(input, ids);
}
// Given a sequence of ids, decodes it into a detokenized output.
bool GemmaTokenizer::Decode(const std::vector<int>& ids,
std::string* detokenized) const {
return impl_->Decode(ids, detokenized);
return impl_ && impl_->Decode(ids, detokenized);
}
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const ModelInfo& info, size_t pos,
std::string& prompt) {
Wrap(info, pos, prompt);
bool GemmaChatTemplate::Init(const GemmaTokenizer& tokenizer, Model model) {
sot_user_.reserve(3);
if (!tokenizer.Encode("<start_of_turn>user\n", &sot_user_)) return false;
sot_model_.reserve(3);
HWY_ASSERT(tokenizer.Encode("<start_of_turn>model\n", &sot_model_));
eot_.reserve(2);
HWY_ASSERT(tokenizer.Encode("<end_of_turn>\n", &eot_));
HWY_ASSERT(tokenizer.Encode("\n", &pali_sep_));
vlm_soi_.reserve(2);
HWY_ASSERT(tokenizer.Encode("\n\n<start_of_image>", &vlm_soi_));
vlm_eoi_.reserve(2);
HWY_ASSERT(tokenizer.Encode("<end_of_image>\n\n", &vlm_eoi_));
return true;
}
std::vector<int> GemmaChatTemplate::Apply(size_t pos,
const std::vector<int>& ids) const {
HWY_ASSERT_M(!sot_user_.empty() && !sot_model_.empty() && !eot_.empty(),
"GemmaChatTemplate has not been initialized.");
std::vector<int> out;
out.reserve(eot_.size() + sot_user_.size() + ids.size() + eot_.size() +
sot_model_.size());
// Start with BOS, or prepend end_of_turn if this is a continuation.
if (pos == 0) {
out.push_back(BOS_ID);
} else {
out.insert(out.cend(), eot_.cbegin(), eot_.cend());
}
// Start of user turn, user prompt, end of turn; then start of model turn.
out.insert(out.cend(), sot_user_.cbegin(), sot_user_.cend());
out.insert(out.cend(), ids.cbegin(), ids.cend());
out.insert(out.cend(), eot_.cbegin(), eot_.cend());
out.insert(out.cend(), sot_model_.cbegin(), sot_model_.cend());
return out;
}
std::vector<int> GemmaChatTemplate::WrapPali(const std::vector<int>& text_part,
size_t image_batch_size) const {
HWY_ASSERT_M(!pali_sep_.empty(),
"GemmaChatTemplate has not been initialized.");
std::vector<int> out;
out.reserve(image_batch_size + 1 + text_part.size() + pali_sep_.size());
out.resize(image_batch_size, 0);
out.push_back(BOS_ID);
out.insert(out.cend(), text_part.cbegin(), text_part.cend());
out.insert(out.cend(), pali_sep_.cbegin(), pali_sep_.cend());
return out;
}
std::vector<int> GemmaChatTemplate::WrapVLM(const std::vector<int>& text_part,
size_t image_batch_size) const {
HWY_ASSERT_M(!vlm_soi_.empty() && !vlm_eoi_.empty(),
"GemmaChatTemplate has not been initialized.");
std::vector<int> out;
out.reserve(text_part.size() + vlm_soi_.size() + image_batch_size +
vlm_eoi_.size());
out.insert(out.cend(), text_part.cbegin(), text_part.cend());
out.insert(out.cend(), vlm_soi_.cbegin(), vlm_soi_.cend());
out.insert(out.cend(), image_batch_size, -2);
out.insert(out.cend(), vlm_eoi_.cbegin(), vlm_eoi_.cend());
return out;
}
// Text
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const GemmaChatTemplate& chat_template,
const ModelInfo& info, size_t pos,
const std::string& prompt) {
std::vector<int> tokens;
HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
// Both pre-trained and instruction-tuned require BOS as first token.
if (pos == 0) {
tokens.insert(tokens.begin(), BOS_ID);
}
// PaliGemma separator. The SEP token "\n" is always tokenized separately.
if (info.wrapping == PromptWrapping::PALIGEMMA
// || info.wrapping == PromptWrapping::GEMMA_VLM
) {
std::vector<int> sep_tokens;
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
switch (info.wrapping) {
case PromptWrapping::GEMMA_IT:
case PromptWrapping::GEMMA_VLM:
return chat_template.Apply(pos, tokens);
default:
if (pos == 0) {
tokens.insert(tokens.cbegin(), BOS_ID);
}
return tokens;
}
return tokens;
}
std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
size_t pos, std::vector<int>& tokens,
size_t image_batch_size, size_t max_image_batch_size) {
HWY_ASSERT(info.wrapping == PromptWrapping::GEMMA_VLM);
size_t num_images = hwy::DivCeil(image_batch_size, max_image_batch_size);
std::vector<int> sep_tokens;
HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
std::string begin_image_prompt = "\n\n<start_of_image>";
std::vector<int> begin_image_tokens =
WrapAndTokenize(tokenizer, info, pos, begin_image_prompt);
std::string end_image_prompt = "<end_of_image>\n\n";
std::vector<int> end_image_tokens =
WrapAndTokenize(tokenizer, info, pos, end_image_prompt);
for (size_t i = 0; i < num_images; ++i) {
tokens.insert(tokens.begin(), begin_image_tokens.begin(),
begin_image_tokens.end());
tokens.insert(tokens.begin() + begin_image_tokens.size(), image_batch_size,
-2);
tokens.insert(tokens.begin() + begin_image_tokens.size() + image_batch_size,
end_image_tokens.begin(), end_image_tokens.end());
// Vision
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const GemmaChatTemplate& chat_template,
const ModelInfo& info, size_t pos,
const std::string& prompt,
size_t image_batch_size) {
std::vector<int> text_part;
HWY_ASSERT(tokenizer.Encode(prompt, &text_part));
switch (info.wrapping) {
case PromptWrapping::PALIGEMMA:
HWY_ASSERT(pos == 0);
return chat_template.WrapPali(text_part, image_batch_size);
case PromptWrapping::GEMMA_VLM:
return chat_template.Apply(
pos, chat_template.WrapVLM(text_part, image_batch_size));
default:
HWY_ASSERT_M(false, "Current variant does not support vision prompt.");
}
return tokens;
}
} // namespace gcpp

View File

@ -54,13 +54,43 @@ class GemmaTokenizer {
std::unique_ptr<Impl> impl_;
};
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const ModelInfo& info, size_t pos,
std::string& prompt);
class GemmaChatTemplate {
public:
GemmaChatTemplate() = default;
explicit GemmaChatTemplate(const GemmaTokenizer& tokenizer, Model model) {
(void)Init(tokenizer, model);
}
std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
size_t pos, std::vector<int>& tokens,
size_t image_batch_size, size_t max_image_batch_size);
// Returns false if the tokenizer is not available (as in optimize_test.cc).
bool Init(const GemmaTokenizer& tokenizer, Model model);
// Given prompt tokens, this returns the wrapped prompt including BOS and
// any "start_of_turn" structure required by the model.
std::vector<int> Apply(size_t pos, const std::vector<int>& ids) const;
std::vector<int> WrapPali(const std::vector<int>& text_part,
size_t image_batch_size) const;
std::vector<int> WrapVLM(const std::vector<int>& text_part,
size_t image_batch_size) const;
private:
std::vector<int> sot_user_;
std::vector<int> sot_model_;
std::vector<int> eot_;
std::vector<int> pali_sep_;
std::vector<int> vlm_soi_;
std::vector<int> vlm_eoi_;
};
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const GemmaChatTemplate& chat_template,
const ModelInfo& info, size_t pos,
const std::string& prompt);
std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
const GemmaChatTemplate& chat_template,
const ModelInfo& info, size_t pos,
const std::string& prompt,
size_t image_batch_size);
} // namespace gcpp