mirror of https://github.com/google/gemma.cpp.git

Refactor `WrapAndTokenize` to work properly with Gemma3

commit ca4ee2b63f (parent 76a81ac2d6)
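At a glance: `WrapAndTokenize` now takes the model's `GemmaChatTemplate`, which encodes the `<start_of_turn>`/`<end_of_turn>` markers once and splices their token IDs in directly, instead of mutating the prompt string and re-tokenizing the control tokens. A minimal before/after sketch of a call site, using names as they appear in the hunks below:

    // Before: the prompt string was wrapped in place by Wrap().
    std::vector<int> tokens =
        WrapAndTokenize(model.Tokenizer(), model.Info(), /*pos=*/0, prompt);

    // After: the chat template is passed explicitly; the prompt is no
    // longer mutated and is taken by const reference.
    std::vector<int> tokens =
        WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), model.Info(),
                        /*pos=*/0, prompt);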
@@ -69,7 +69,7 @@ class GemmaEnv {
   }
 
   std::vector<int> WrapAndTokenize(std::string& input) const {
-    return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->Info(), 0, input);
+    return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->ChatTemplate(), model_->Info(), 0, input);
   }
 
   std::string StringFromTokens(const std::vector<int>& tokens) const {
@@ -178,16 +178,18 @@ TEST_F(GemmaTest, Multiturn) {
   TimingInfo timing_info{.verbosity = 0};
   // First "say" something slightly unusual.
   std::string mutable_prompt = "I have a car and its color is turquoise.";
-  std::vector<int> tokens = WrapAndTokenize(model->Tokenizer(), model->Info(),
-                                            abs_pos, mutable_prompt);
+  std::vector<int> tokens = WrapAndTokenize(model->Tokenizer(),
+                                            model->ChatTemplate(),
+                                            model->Info(), abs_pos,
+                                            mutable_prompt);
   model->Generate(runtime_config, tokens, abs_pos, s_env->MutableKVCache(),
                   timing_info);
   // Note: we do not rewind any <end_of_turn> tokens here. If the model
   // produced one and WrapAndTokenize() inserts another one, it will just be
   // duplicated.
   mutable_prompt = "Please repeat all prior statements.";
-  tokens = WrapAndTokenize(model->Tokenizer(), model->Info(), abs_pos,
-                           mutable_prompt);
+  tokens = WrapAndTokenize(model->Tokenizer(), model->ChatTemplate(),
+                           model->Info(), abs_pos, mutable_prompt);
   // Reset the `response` string here, then check that the model actually has
   // access to the previous turn by asking to reproduce.
   response.clear();
@@ -75,7 +75,8 @@ int main(int argc, char** argv) {
   // Tokenize instructions.
   std::string prompt = "Write a greeting to the world.";
   const std::vector<int> tokens = gcpp::WrapAndTokenize(
-      model.Tokenizer(), loader.Info(), generated, prompt);
+      model.Tokenizer(), model.ChatTemplate(), loader.Info(),
+      generated, prompt);
   const size_t prompt_size = tokens.size();
 
   // This callback function gets invoked every time a token is generated
@@ -72,7 +72,8 @@ class SimplifiedGemma {
   size_t generated = 0;
 
   const std::vector<int> tokens = gcpp::WrapAndTokenize(
-      model_.Tokenizer(), loader_.Info(), generated, prompt);
+      model_.Tokenizer(), model_.ChatTemplate(), loader_.Info(),
+      generated, prompt);
   const size_t prompt_size = tokens.size();
 
   // This callback function gets invoked every time a token is generated
@@ -148,18 +148,6 @@ const char* ParseType(const std::string& type_string, Type& type) {
   return kErrorMessageBuffer.c_str();
 }
 
-void Wrap(const ModelInfo& info, size_t pos, std::string& prompt) {
-  // Instruction-tuned models are trained to expect control tokens.
-  if (info.wrapping == PromptWrapping::GEMMA_IT) {
-    // Prepend "<end_of_turn>" if this is a multi-turn dialogue continuation.
-    const std::string start = (pos == 0)
-                                  ? "<start_of_turn>user\n"
-                                  : "<end_of_turn>\n<start_of_turn>user\n";
-    prompt = start + prompt + "<end_of_turn>\n<start_of_turn>model\n";
-  }
-}
-
 float EmbeddingScaling(size_t model_dim) {
   // Round to bf16 to match Gemma's Embedder, which casts before mul.
   return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
@@ -43,9 +43,6 @@ const char* ParseType(const std::string& type_string, Type& type);
 const char* ModelString(Model model, PromptWrapping wrapping);
 const char* StringFromType(Type type);
 
-// Wraps the given prompt using the expected control tokens for IT models.
-void Wrap(const ModelInfo& info, size_t pos, std::string& prompt);
-
 // Returns the scale value to use for the embedding (basically sqrt model_dim).
 float EmbeddingScaling(size_t model_dim);
@@ -40,7 +40,7 @@ namespace gcpp {
 
 Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
              const ModelInfo& info, MatMulEnv& env)
-    : env_(env), tokenizer_(tokenizer_path) {
+    : env_(env), tokenizer_(tokenizer_path), chat_template_(tokenizer_) {
   model_.Load(weights, info.model, info.weight, info.wrapping,
               env_.parallel.Pools().Pool(0),
               /*tokenizer_proto=*/nullptr);
@@ -51,10 +51,11 @@ Gemma::Gemma(const Path& weights, MatMulEnv& env) : env_(env) {
   model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT,
               env_.parallel.Pools().Pool(0), &tokenizer_proto);
   tokenizer_.Deserialize(tokenizer_proto);
+  chat_template_.Init(tokenizer_);
 }
 
 Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env)
-    : env_(env), tokenizer_(std::move(tokenizer)) {
+    : env_(env), tokenizer_(std::move(tokenizer)), chat_template_(tokenizer_) {
   HWY_ASSERT(info.weight == Type::kF32);
   model_.Allocate(info.model, info.weight, env_.parallel.Pools().Pool(0));
 }
@@ -213,6 +213,7 @@ class Gemma {
                         .weight = model_.Config().weight});
   }
   const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
+  const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
   const ModelWeightsStorage& Weights() const { return model_; }
   ModelWeightsStorage& MutableWeights() { return model_; }
   void Save(const Path& weights, hwy::ThreadPool& pool) {
@@ -256,6 +257,7 @@ class Gemma {
   MatMulEnv& env_;
 
   GemmaTokenizer tokenizer_;
+  GemmaChatTemplate chat_template_;
   // Type-erased so that this can be defined in the header.
   ModelWeightsStorage model_;
 };
gemma/run.cc (33 changed lines)
@@ -162,16 +162,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
       continue;
     }
 
-    // Wrap, tokenize and maybe log prompt tokens.
-    std::vector<int> prompt = WrapAndTokenize(
-        model.Tokenizer(), model.Info(), abs_pos, prompt_string);
-    prompt_size = prompt.size();
-    if constexpr (kVerboseLogTokens) {
-      for (int i = 0; i < prompt_size; ++i) {
-        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
-      }
-    }
+    std::vector<int> prompt;
 
     // Set up runtime config.
     TimingInfo timing_info = {.verbosity = app.verbosity};
     RuntimeConfig runtime_config = {.gen = &gen,
@@ -182,22 +173,26 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
     args.CopyTo(runtime_config);
     size_t prefix_end = 0;
     if (have_image) {
+      prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
+                               model.Info(), abs_pos, prompt_string,
+                               image_tokens.BatchSize());
       runtime_config.image_tokens = &image_tokens;
-      if (model.Info().wrapping == PromptWrapping::PALIGEMMA) {
-        prompt.insert(prompt.begin(), image_tokens.BatchSize(), 0);
-      } else if (model.Info().wrapping == PromptWrapping::GEMMA_VLM) {
-        size_t seq_len = model.GetModelConfig().vit_config.seq_len;
-        size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
-        prompt =
-            WrapVLM(model.Tokenizer(), model.Info(), abs_pos, prompt,
-                    image_tokens.BatchSize(), seq_len / (pool_dim * pool_dim));
-      }
       prompt_size = prompt.size();
       // The end of the prefix for prefix-LM style attention in Paligemma.
       // See Figure 2 of https://arxiv.org/abs/2407.07726.
       prefix_end = prompt_size;
       // We need to look at all the tokens for the prefix.
       runtime_config.prefill_tbatch_size = prompt_size;
+    } else {
+      prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
+                               model.Info(), abs_pos, prompt_string);
+      prompt_size = prompt.size();
+    }
+
+    if constexpr (kVerboseLogTokens) {
+      for (int i = 0; i < prompt_size; ++i) {
+        fprintf(stderr, "DDD TOKEN %3d: %6d\n", i, prompt[i]);
+      }
     }
 
     // Generate until EOS or max_generated_tokens.
@@ -114,57 +114,96 @@ bool GemmaTokenizer::Decode(const std::vector<int>& ids,
   return impl_->Decode(ids, detokenized);
 }
 
-std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
-                                 const ModelInfo& info, size_t pos,
-                                 std::string& prompt) {
-  Wrap(info, pos, prompt);
-  std::vector<int> tokens;
-  HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
-  // Both pre-trained and instruction-tuned require BOS as first token.
-  if (pos == 0) {
-    tokens.insert(tokens.begin(), BOS_ID);
-  }
-
-  // PaliGemma separator. The SEP token "\n" is always tokenized separately.
-  if (info.wrapping == PromptWrapping::PALIGEMMA
-      // || info.wrapping == PromptWrapping::GEMMA_VLM
-  ) {
-    std::vector<int> sep_tokens;
-    HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
-    tokens.insert(tokens.end(), sep_tokens.begin(), sep_tokens.end());
-  }
-
-  return tokens;
-}
-
-std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
-                         size_t pos, std::vector<int>& tokens,
-                         size_t image_batch_size, size_t max_image_batch_size) {
-  HWY_ASSERT(info.wrapping == PromptWrapping::GEMMA_VLM);
-  size_t num_images = hwy::DivCeil(image_batch_size, max_image_batch_size);
-
-  std::vector<int> sep_tokens;
-  HWY_ASSERT(tokenizer.Encode("\n", &sep_tokens));
-
-  std::string begin_image_prompt = "\n\n<start_of_image>";
-  std::vector<int> begin_image_tokens =
-      WrapAndTokenize(tokenizer, info, pos, begin_image_prompt);
-
-  std::string end_image_prompt = "<end_of_image>\n\n";
-  std::vector<int> end_image_tokens =
-      WrapAndTokenize(tokenizer, info, pos, end_image_prompt);
-
-  for (size_t i = 0; i < num_images; ++i) {
-    tokens.insert(tokens.begin(), begin_image_tokens.begin(),
-                  begin_image_tokens.end());
-    tokens.insert(tokens.begin() + begin_image_tokens.size(), image_batch_size,
-                  -2);
-    tokens.insert(tokens.begin() + begin_image_tokens.size() + image_batch_size,
-                  end_image_tokens.begin(), end_image_tokens.end());
-  }
-
-  return tokens;
+GemmaChatTemplate::GemmaChatTemplate(const GemmaTokenizer& tokenizer) {
+  Init(tokenizer);
+}
+
+void GemmaChatTemplate::Init(const GemmaTokenizer& tokenizer) {
+  sot_user_.reserve(3);
+  HWY_ASSERT(tokenizer.Encode("<start_of_turn>user\n", &sot_user_));
+  sot_model_.reserve(3);
+  HWY_ASSERT(tokenizer.Encode("<start_of_turn>model\n", &sot_model_));
+  eot_.reserve(2);
+  HWY_ASSERT(tokenizer.Encode("<end_of_turn>\n", &eot_));
+}
+
+std::vector<int> GemmaChatTemplate::Apply(size_t pos,
+                                          const std::vector<int>& ids) const {
+  HWY_ASSERT_M(!sot_user_.empty() && !sot_model_.empty() && !eot_.empty(),
+               "GemmaChatTemplate has not been initialized.");
+  std::vector<int> out;
+  out.reserve(eot_.size() +
+              sot_user_.size() +
+              ids.size() +
+              eot_.size() +
+              sot_model_.size());
+  if (pos > 0) {
+    out.insert(out.cend(), eot_.cbegin(), eot_.cend());
+  } else {
+    out.push_back(BOS_ID);
+  }
+  out.insert(out.cend(), sot_user_.cbegin(), sot_user_.cend());
+  out.insert(out.cend(), ids.cbegin(), ids.cend());
+  out.insert(out.cend(), eot_.cbegin(), eot_.cend());
+  out.insert(out.cend(), sot_model_.cbegin(), sot_model_.cend());
+  return out;
+}
+
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt) {
+  std::vector<int> tokens;
+  HWY_ASSERT(tokenizer.Encode(prompt, &tokens));
+  switch (info.wrapping) {
+    case PromptWrapping::GEMMA_IT:
+    case PromptWrapping::GEMMA_VLM:
+      return chat_template.Apply(pos, tokens);
+    default:
+      if (pos == 0) {
+        tokens.insert(tokens.cbegin(), BOS_ID);
+      }
+      return tokens;
+  }
+}
+
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt,
+                                 size_t image_batch_size) {
+  std::vector<int> text_part;
+  HWY_ASSERT(tokenizer.Encode(prompt, &text_part));
+  std::vector<int> tokens;
+  switch (info.wrapping) {
+    case PromptWrapping::PALIGEMMA: {
+      std::vector<int> sep;
+      HWY_ASSERT(tokenizer.Encode("\n", &sep));
+      tokens.reserve(image_batch_size + 1 + text_part.size() + sep.size());
+      tokens.resize(image_batch_size, 0);
+      HWY_ASSERT(pos == 0);
+      tokens.push_back(BOS_ID);
+      tokens.insert(tokens.cend(), text_part.cbegin(), text_part.cend());
+      tokens.insert(tokens.cend(), sep.cbegin(), sep.cend());
+      return tokens;
+    }
+    case PromptWrapping::GEMMA_VLM: {
+      std::vector<int> soi;
+      soi.reserve(2);
+      HWY_ASSERT(tokenizer.Encode("\n\n<start_of_image>", &soi));
+      std::vector<int> eoi;
+      eoi.reserve(2);
+      HWY_ASSERT(tokenizer.Encode("<end_of_image>\n\n", &eoi));
+      tokens.reserve(text_part.size() + soi.size() + image_batch_size +
+                     eoi.size());
+      tokens.insert(tokens.cend(), text_part.cbegin(), text_part.cend());
+      tokens.insert(tokens.cend(), soi.cbegin(), soi.cend());
+      tokens.insert(tokens.cend(), image_batch_size, -2);
+      tokens.insert(tokens.cend(), eoi.cbegin(), eoi.cend());
+      return chat_template.Apply(pos, tokens);
+    }
+    default:
+      HWY_ASSERT_M(false, "Current variant does not support vision prompt.");
+  }
 }
 
 }  // namespace gcpp
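For orientation, a sketch of the sequence `GemmaChatTemplate::Apply` assembles; the marker spellings come from `Init` above, and the concrete token IDs depend on the tokenizer:

    // pos == 0 (first turn):
    //   BOS_ID, <start_of_turn>user\n, ids..., <end_of_turn>\n, <start_of_turn>model\n
    //
    // pos > 0 (continuation): the model's previous turn is closed first,
    // and no BOS is emitted:
    //   <end_of_turn>\n, <start_of_turn>user\n, ids..., <end_of_turn>\n, <start_of_turn>model\n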
@@ -54,13 +54,30 @@ class GemmaTokenizer {
   std::unique_ptr<Impl> impl_;
 };
 
-std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
-                                 const ModelInfo& info, size_t pos,
-                                 std::string& prompt);
-
-std::vector<int> WrapVLM(const GemmaTokenizer& tokenizer, const ModelInfo& info,
-                         size_t pos, std::vector<int>& tokens,
-                         size_t image_batch_size, size_t max_image_batch_size);
+class GemmaChatTemplate {
+ public:
+  GemmaChatTemplate() = default;
+  explicit GemmaChatTemplate(const GemmaTokenizer& tokenizer);
+
+  void Init(const GemmaTokenizer& tokenizer);
+  std::vector<int> Apply(size_t pos, const std::vector<int>& ids) const;
+
+ private:
+  std::vector<int> sot_user_;
+  std::vector<int> sot_model_;
+  std::vector<int> eot_;
+};
+
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt);
+
+std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
+                                 const GemmaChatTemplate& chat_template,
+                                 const ModelInfo& info, size_t pos,
+                                 const std::string& prompt,
+                                 size_t image_batch_size);
 
 }  // namespace gcpp
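A hedged usage sketch of the new declarations. The include path and the helper name `TokenizeTurn` are assumptions for illustration, not part of the commit; `tokenizer`, `info`, and `prompt` are caller-supplied:

    #include <string>
    #include <vector>

    #include "gemma/tokenizer.h"  // assumed header for GemmaTokenizer,
                                  // GemmaChatTemplate, WrapAndTokenize

    // Hypothetical helper: build the template once per tokenizer, then
    // reuse it for every prompt.
    std::vector<int> TokenizeTurn(const gcpp::GemmaTokenizer& tokenizer,
                                  const gcpp::ModelInfo& info, size_t pos,
                                  const std::string& prompt) {
      gcpp::GemmaChatTemplate chat_template(tokenizer);
      return gcpp::WrapAndTokenize(tokenizer, chat_template, info, pos, prompt);
    }

In practice the `Gemma` constructors above build `chat_template_` from the tokenizer themselves, so callers normally just pass `model.ChatTemplate()`.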