mirror of https://github.com/google/gemma.cpp.git
Merge pull request #529 from ufownl:refactor/wrap_and_tokenize
PiperOrigin-RevId: 745174371
This commit is contained in: commit bef91a3f03
@@ -69,7 +69,8 @@ class GemmaEnv {
   }

   std::vector<int> WrapAndTokenize(std::string& input) const {
-    return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->Info(), 0, input);
+    return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->ChatTemplate(),
+                                 model_->Info(), 0, input);
   }

   std::string StringFromTokens(const std::vector<int>& tokens) const {
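Note: call sites now pass the model's GemmaChatTemplate alongside the tokenizer, and the free-function prompt parameter becomes const (the old overload mutated the string via Wrap()). A minimal sketch of the new call shape, assuming a loaded gcpp::Gemma named `model`:

  // Sketch only: the string is no longer rewritten in place; all wrapping now
  // happens on the token ids inside WrapAndTokenize().
  std::string prompt = "Hello, Gemma.";
  const std::vector<int> ids = gcpp::WrapAndTokenize(
      model.Tokenizer(), model.ChatTemplate(), model.Info(), /*pos=*/0, prompt);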
@@ -178,22 +178,25 @@ TEST_F(GemmaTest, Multiturn) {
   TimingInfo timing_info{.verbosity = 0};
   // First "say" something slightly unusual.
   std::string mutable_prompt = "I have a car and its color is turquoise.";
-  std::vector<int> tokens = WrapAndTokenize(model->Tokenizer(), model->Info(),
-                                            abs_pos, mutable_prompt);
+  std::vector<int> tokens =
+      WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(), model->Info(),
+                      abs_pos, mutable_prompt);

   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                   timing_info);
   // Note: we do not rewind any <end_of_turn> tokens here. If the model
   // produced one and WrapAndTokenize() inserts another one, it will just be
   // duplicated.
   mutable_prompt = "Please repeat all prior statements.";
-  tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
-                           mutable_prompt);
+  tokens = WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
+                           model->Info(), abs_pos, mutable_prompt);

   // Reset the `response` string here, then check that the model actually has
   // access to the previous turn by asking to reproduce.
   response.clear();
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                   timing_info);
-  fprintf(stderr, "decoded: %s\n", response.c_str());
+  fprintf(stderr, "decoded: '%s'\n", response.c_str());
   bool remembered_turquoise =
       response.find("turquoise") != std::string::npos;  // NOLINT
   bool remembered_car = response.find("car") != std::string::npos;  // NOLINT
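Note: the second WrapAndTokenize call above runs with abs_pos > 0, so GemmaChatTemplate::Apply() (added later in this diff) prepends <end_of_turn> rather than BOS, which is why the test comment accepts a possibly duplicated <end_of_turn>. An illustrative sketch of the continuation wrapping, reusing the test's `model` and `abs_pos`:

  // Sketch only: how the second turn is wrapped when abs_pos > 0.
  std::vector<int> raw_ids;
  HWY_ASSERT(model->Tokenizer().Encode("Please repeat all prior statements.",
                                       &raw_ids));
  std::vector<int> wrapped = model->ChatTemplate().Apply(abs_pos, raw_ids);
  // wrapped = <end_of_turn>\n <start_of_turn>user\n raw_ids <end_of_turn>\n
  //           <start_of_turn>model\n
  // A first turn (abs_pos == 0) starts with <bos> instead of <end_of_turn>.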
@@ -74,8 +74,9 @@ int main(int argc, char** argv) {

   // Tokenize instructions.
   std::string prompt = "Write a greeting to the world.";
-  const std::vector<int> tokens = gcpp::WrapAndTokenize(
-      model.Tokenizer(), loader.Info(), generated, prompt);
+  const std::vector<int> tokens =
+      gcpp::WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
+                            loader.Info(), generated, prompt);
   const size_t prompt_size = tokens.size();

   // This callback function gets invoked every time a token is generated
@@ -72,7 +72,8 @@ class SimplifiedGemma {
   size_t generated = 0;

   const std::vector<int> tokens = gcpp::WrapAndTokenize(
-      model_.Tokenizer(), loader_.Info(), generated, prompt);
+      model_.Tokenizer(), model_.ChatTemplate(), loader_.Info(),
+      generated, prompt);
   const size_t prompt_size = tokens.size();

   // This callback function gets invoked every time a token is generated
@@ -25,6 +25,8 @@
 #include <vector>

 #include "util/basics.h"  // BF16
+// TODO: change include when PromptWrapping is moved.
+#include "compression/shared.h"  // PromptWrapping
 #include "hwy/base.h"

 namespace gcpp {
@@ -79,7 +81,7 @@ constexpr PromptWrapping kPromptWrapping[] = {
     PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA,  // PG2 3B 224/448
     PromptWrapping::PALIGEMMA, PromptWrapping::PALIGEMMA,  // PG2 10B 224/448
     PromptWrapping::GEMMA_VLM,                             // Gemma3 4B
-    PromptWrapping::GEMMA_PT,                              // Gemma3 1B
+    PromptWrapping::GEMMA_IT,                              // Gemma3 1B
     PromptWrapping::GEMMA_VLM,                             // Gemma3 12B
     PromptWrapping::GEMMA_VLM,                             // Gemma3 27B
 };
@@ -44,6 +44,7 @@ const char* ModelString(Model model, PromptWrapping wrapping);
 const char* StringFromType(Type type);

 // Wraps the given prompt using the expected control tokens for IT models.
+// `GemmaChatTemplate` is preferred if a tokenized return value is fine.
 void Wrap(const ModelInfo& info, size_t pos, std::string& prompt);

 // Returns the scale value to use for the embedding (basically sqrt model_dim).
@@ -44,6 +44,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
   model_.Load(weights, info.model, info.weight, info.wrapping,
               env_.parallel.Pools().Pool(0),
               /*tokenizer_proto=*/nullptr);
+  chat_template_.Init(tokenizer_, model_.Config().model);
 }

 Gemma::Gemma(const Path& weights, MatMulEnv& env) : env_(env) {
@@ -51,10 +52,13 @@ Gemma::Gemma(const Path& weights, MatMulEnv& env) : env_(env) {
   model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT,
               env_.parallel.Pools().Pool(0), &tokenizer_proto);
   tokenizer_.Deserialize(tokenizer_proto);
+  chat_template_.Init(tokenizer_, model_.Config().model);
 }

 Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env)
-    : env_(env), tokenizer_(std::move(tokenizer)) {
+    : env_(env),
+      tokenizer_(std::move(tokenizer)),
+      chat_template_(tokenizer_, info.model) {
   HWY_ASSERT(info.weight == Type::kF32);
   model_.Allocate(info.model, info.weight, env_.parallel.Pools().Pool(0));
 }
@@ -213,6 +213,7 @@ class Gemma {
                                 .weight = model_.Config().weight});
   }
   const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
+  const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
   const ModelWeightsStorage& Weights() const { return model_; }
   ModelWeightsStorage& MutableWeights() { return model_; }
   void Save(const Path& weights, hwy::ThreadPool& pool) {
@@ -256,6 +257,7 @@ class Gemma {
   MatMulEnv& env_;

   GemmaTokenizer tokenizer_;
+  GemmaChatTemplate chat_template_;
   // Type-erased so that this can be defined in the header.
   ModelWeightsStorage model_;
 };
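Note: Gemma now owns the chat template. Both constructors initialize `chat_template_` from `tokenizer_` (the member is declared after the tokenizer, so the move-constructor's initializer list sees a valid tokenizer), and callers reach it through the new ChatTemplate() accessor. A first-turn sketch, assuming a loaded gcpp::Gemma named `model`:

  // Sketch only: applying the template to already-tokenized ids via the
  // accessor, without going through WrapAndTokenize().
  std::vector<int> ids;
  HWY_ASSERT(model.Tokenizer().Encode("Hi there", &ids));
  const std::vector<int> wrapped = model.ChatTemplate().Apply(/*pos=*/0, ids);
  // wrapped = <bos> <start_of_turn>user\n ids <end_of_turn>\n <start_of_turn>model\n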
gemma/run.cc (34 changed lines)
@@ -162,16 +162,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
       continue;
     }

-    // Wrap, tokenize and maybe log prompt tokens.
-    std::vector<int> prompt = WrapAndTokenize(
-        model.Tokenizer(), model.Info(), abs_pos, prompt_string);
-    prompt_size = prompt.size();
-    if constexpr (kVerboseLogTokens) {
-      for (int i = 0; i < prompt_size; ++i) {
-        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
-      }
-    }
-
     // Set up runtime config.
     TimingInfo timing_info = {.verbosity = app.verbosity};
     RuntimeConfig runtime_config = {.gen = &gen,
@@ -181,23 +171,29 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
                                     .use_spinning = app.spin};
     args.CopyTo(runtime_config);
     size_t prefix_end = 0;
+
+    std::vector<int> prompt;
     if (have_image) {
+      prompt =
+          WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
+                          abs_pos, prompt_string, image_tokens.BatchSize());
       runtime_config.image_tokens = &image_tokens;
-      if (model.Info().wrapping == PromptWrapping::PALIGEMMA) {
-        prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
-      } else if (model.Info().wrapping == PromptWrapping::GEMMA_VLM) {
-        size_t seq_len = model.GetModelConfig().vit_config.seq_len;
-        size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
-        prompt =
-            WrapVLM(model.Tokenizer(), model.Info(), abs_pos, prompt,
-                    image_tokens.BatchSize(), seq_len / (pool_dim * pool_dim));
-      }
       prompt_size = prompt.size();
       // The end of the prefix for prefix-LM style attention in Paligemma.
       // See Figure 2 of https://arxiv.org/abs/2407.07726.
       prefix_end = prompt_size;
       // We need to look at all the tokens for the prefix.
       runtime_config.prefill_tbatch_size = prompt_size;
+    } else {
+      prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
+                               model.Info(), abs_pos, prompt_string);
+      prompt_size = prompt.size();
+    }
+
+    if constexpr (kVerboseLogTokens) {
+      for (int i = 0; i < prompt_size; ++i) {
+        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
+      }
     }

     // Generate until EOS or max_generated_tokens.
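Note: the REPL now tokenizes through a single entry point and only branches on whether an image is present. For reference, a sketch of the vision path using the variables from ReplGemma above; the resulting layouts come from GemmaChatTemplate::WrapPali and WrapVLM defined later in this diff (the -2 entries are placeholders later overwritten by image tokens):

  // Sketch only. For PALIGEMMA the result is
  //   [0] * image_batch, <bos>, text tokens, "\n"
  // and for GEMMA_VLM it is the chat-template turn wrapped around
  //   text tokens, "\n\n<start_of_image>", [-2] * image_batch, "<end_of_image>\n\n".
  std::vector<int> vision_prompt = WrapAndTokenize(
      model.Tokenizer(), model.ChatTemplate(), model.Info(), abs_pos,
      prompt_string, image_tokens.BatchSize());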
@@ -100,71 +100,123 @@ void GemmaTokenizer::Deserialize(const std::string& tokenizer_proto) {

 bool GemmaTokenizer::Encode(const std::string& input,
                             std::vector<std::string>* pieces) const {
-  return impl_->Encode(input, pieces);
+  return impl_ && impl_->Encode(input, pieces);
 }

 bool GemmaTokenizer::Encode(const std::string& input,
                             std::vector<int>* ids) const {
-  return impl_->Encode(input, ids);
+  return impl_ && impl_->Encode(input, ids);
 }

 // Given a sequence of ids, decodes it into a detokenized output.
 bool GemmaTokenizer::Decode(const std::vector<int>& ids,
                             std::string* detokenized) const {
-  return impl_->Decode(ids, detokenized);
+  return impl_ && impl_->Decode(ids, detokenized);
 }

-std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
-                                 const ModelInfo& info, size_t pos,
-                                 std::string& prompt) {
-  Wrap(info, pos, prompt);
-  std::vector<int> tokens;
-  HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
-  // Both pre-trained and instruction-tuned require BOS as first token.
-  if (pos == 0) {
-    tokens.insert(tokens.begin(), BOS_ID);
-  }
-
-  // PaliGemma separator. The SEP token "\n" is always tokenized separately.
-  if (info.wrapping == PromptWrapping::PALIGEMMA
-      // || info.wrapping == PromptWrapping::GEMMA_VLM
-  ) {
-    std::vector<int> sep_tokens;
-    HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
-    tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
-  }
-
-  return tokens;
-}
-
-std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
-                         size_t pos, std::vector<int>& tokens,
-                         size_t image_batch_size, size_t max_image_batch_size) {
-  HWY_ASSERT(info.wrapping == PromptWrapping::GEMMA_VLM);
-  size_t num_images = hwy::DivCeil(image_batch_size, max_image_batch_size);
-
-  std::vector<int> sep_tokens;
-  HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
-
-  std::string begin_image_prompt = "\n\n<start_of_image>";
-  std::vector<int> begin_image_tokens =
-      WrapAndTokenize(tokenizer, info, pos, begin_image_prompt);
-
-  std::string end_image_prompt = "<end_of_image>\n\n";
-  std::vector<int> end_image_tokens =
-      WrapAndTokenize(tokenizer, info, pos, end_image_prompt);
-
-  for (size_t i = 0; i < num_images; ++i) {
-    tokens.insert(tokens.begin(), begin_image_tokens.begin(),
-                  begin_image_tokens.end());
-    tokens.insert(tokens.begin() + begin_image_tokens.size(), image_batch_size,
-                  -2);
-    tokens.insert(tokens.begin() + begin_image_tokens.size() + image_batch_size,
-                  end_image_tokens.begin(), end_image_tokens.end());
-  }
-
-  return tokens;
-}
+bool GemmaChatTemplate::Init(const GemmaTokenizer& tokenizer, Model model) {
+  sot_user_.reserve(3);
+  if (!tokenizer.Encode("<start_of_turn>user\n", &sot_user_)) return false;
+  sot_model_.reserve(3);
+  HWY_ASSERT(tokenizer.Encode("<start_of_turn>model\n", &sot_model_));
+  eot_.reserve(2);
+  HWY_ASSERT(tokenizer.Encode("<end_of_turn>\n", &eot_));
+
+  HWY_ASSERT(tokenizer.Encode("\n", &pali_sep_));
+  vlm_soi_.reserve(2);
+  HWY_ASSERT(tokenizer.Encode("\n\n<start_of_image>", &vlm_soi_));
+  vlm_eoi_.reserve(2);
+  HWY_ASSERT(tokenizer.Encode("<end_of_image>\n\n", &vlm_eoi_));
+  return true;
+}
+
+std::vector<int> GemmaChatTemplate::Apply(size_t pos,
+                                          const std::vector<int>& ids) const {
+  HWY_ASSERT_M(!sot_user_.empty() && !sot_model_.empty() && !eot_.empty(),
+               "GemmaChatTemplate has not been initialized.");
+  std::vector<int> out;
+  out.reserve(eot_.size() + sot_user_.size() + ids.size() + eot_.size() +
+              sot_model_.size());
+
+  // Start with BOS, or prepend end_of_turn if this is a continuation.
+  if (pos == 0) {
+    out.push_back(BOS_ID);
+  } else {
+    out.insert(out.cend(), eot_.cbegin(), eot_.cend());
+  }
+  // Start of user turn, user prompt, end of turn; then start of model turn.
+  out.insert(out.cend(), sot_user_.cbegin(), sot_user_.cend());
+  out.insert(out.cend(), ids.cbegin(), ids.cend());
+  out.insert(out.cend(), eot_.cbegin(), eot_.cend());
+  out.insert(out.cend(), sot_model_.cbegin(), sot_model_.cend());
+  return out;
+}
+
+std::vector<int> GemmaChatTemplate::WrapPali(const std::vector<int>& text_part,
+                                             size_t image_batch_size) const {
+  HWY_ASSERT_M(!pali_sep_.empty(),
+               "GemmaChatTemplate has not been initialized.");
+  std::vector<int> out;
+  out.reserve(image_batch_size + 1 + text_part.size() + pali_sep_.size());
+  out.resize(image_batch_size, 0);
+  out.push_back(BOS_ID);
+  out.insert(out.cend(), text_part.cbegin(), text_part.cend());
+  out.insert(out.cend(), pali_sep_.cbegin(), pali_sep_.cend());
+  return out;
+}
+
+std::vector<int> GemmaChatTemplate::WrapVLM(const std::vector<int>& text_part,
+                                            size_t image_batch_size) const {
+  HWY_ASSERT_M(!vlm_soi_.empty() && !vlm_eoi_.empty(),
+               "GemmaChatTemplate has not been initialized.");
+  std::vector<int> out;
+  out.reserve(text_part.size() + vlm_soi_.size() + image_batch_size +
+              vlm_eoi_.size());
+  out.insert(out.cend(), text_part.cbegin(), text_part.cend());
+  out.insert(out.cend(), vlm_soi_.cbegin(), vlm_soi_.cend());
+  out.insert(out.cend(), image_batch_size, -2);
+  out.insert(out.cend(), vlm_eoi_.cbegin(), vlm_eoi_.cend());
+  return out;
+}
+
+// Text
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt) {
+  std::vector<int> tokens;
+  HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
+  switch (info.wrapping) {
+    case PromptWrapping::GEMMA_IT:
+    case PromptWrapping::GEMMA_VLM:
+      return chat_template.Apply(pos, tokens);
+    default:
+      if (pos == 0) {
+        tokens.insert(tokens.cbegin(), BOS_ID);
+      }
+      return tokens;
+  }
+}
+
+// Vision
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt,
+                                 size_t image_batch_size) {
+  std::vector<int> text_part;
+  HWY_ASSERT(tokenizer.Encode(prompt, &text_part));
+  switch (info.wrapping) {
+    case PromptWrapping::PALIGEMMA:
+      HWY_ASSERT(pos == 0);
+      return chat_template.WrapPali(text_part, image_batch_size);
+    case PromptWrapping::GEMMA_VLM:
+      return chat_template.Apply(
+          pos, chat_template.WrapVLM(text_part, image_batch_size));
+    default:
+      HWY_ASSERT_M(false, "Current variant does not support vision prompt.");
+  }
+}

 }  // namespace gcpp
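Note: the `impl_ &&` guards in Encode()/Decode() and the bool return of Init() let a tokenizer-less setup degrade gracefully instead of dereferencing a null impl_; the header comment cites optimize_test.cc as the motivating case. A purely illustrative sketch of that behavior:

  // Illustrative sketch: with no tokenizer loaded, Encode() returns false and
  // GemmaChatTemplate::Init() reports failure instead of crashing.
  gcpp::GemmaTokenizer empty_tokenizer;               // no impl_ behind it
  std::vector<int> ids;
  bool ok = empty_tokenizer.Encode("hello", &ids);    // false, ids untouched
  gcpp::GemmaChatTemplate tmpl;
  // The Model value is not used by Init() itself; UNKNOWN is just a placeholder.
  bool initialized = tmpl.Init(empty_tokenizer, gcpp::Model::UNKNOWN);
  // initialized == false; Apply() would assert if called on this template.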
@@ -54,13 +54,43 @@ class GemmaTokenizer {
   std::unique_ptr<Impl> impl_;
 };

-std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
-                                 const ModelInfo& info, size_t pos,
-                                 std::string& prompt);
-
-std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
-                         size_t pos, std::vector<int>& tokens,
-                         size_t image_batch_size, size_t max_image_batch_size);
+class GemmaChatTemplate {
+ public:
+  GemmaChatTemplate() = default;
+  explicit GemmaChatTemplate(const GemmaTokenizer& tokenizer, Model model) {
+    (void)Init(tokenizer, model);
+  }
+
+  // Returns false if the tokenizer is not available (as in optimize_test.cc).
+  bool Init(const GemmaTokenizer& tokenizer, Model model);
+
+  // Given prompt tokens, this returns the wrapped prompt including BOS and
+  // any "start_of_turn" structure required by the model.
+  std::vector<int> Apply(size_t pos, const std::vector<int>& ids) const;
+  std::vector<int> WrapPali(const std::vector<int>& text_part,
+                            size_t image_batch_size) const;
+  std::vector<int> WrapVLM(const std::vector<int>& text_part,
+                           size_t image_batch_size) const;
+
+ private:
+  std::vector<int> sot_user_;
+  std::vector<int> sot_model_;
+  std::vector<int> eot_;
+  std::vector<int> pali_sep_;
+  std::vector<int> vlm_soi_;
+  std::vector<int> vlm_eoi_;
+};
+
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt);
+
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt,
+                                 size_t image_batch_size);

 }  // namespace gcpp
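Note: taken together, callers migrate by threading a GemmaChatTemplate through WrapAndTokenize. A minimal end-to-end sketch against the declarations above (construction of `model` and the image tokens is assumed, not part of this diff):

  // Sketch of the migrated call pattern, assuming a loaded gcpp::Gemma `model`
  // and an interactive position abs_pos (0 for a fresh conversation).
  size_t abs_pos = 0;
  std::string prompt = "Write a haiku about turquoise cars.";
  const std::vector<int> tokens = gcpp::WrapAndTokenize(
      model.Tokenizer(), model.ChatTemplate(), model.Info(), abs_pos, prompt);
  // Vision-capable variants use the overload that also takes image_batch_size:
  // gcpp::WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
  //                       abs_pos, prompt, image_tokens.BatchSize());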