mirror of https://github.com/google/gemma.cpp.git
Improve `GemmaChatTemplate` to handle vision prompt wrapping
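The PaliGemma and Gemma 3 vision wrapping previously lived inline in `WrapAndTokenize`. With this change, `Init` pre-encodes the PaliGemma `"\n"` separator and the `<start_of_image>`/`<end_of_image>` marker sequences once, the new `GemmaChatTemplate::WrapPali` and `GemmaChatTemplate::WrapVLM` helpers assemble the wrapped token sequences, and `WrapAndTokenize` reduces to a dispatch on `PromptWrapping`.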
This commit is contained in:
parent c39295f497
commit cc2e14e654
@@ -121,6 +121,11 @@ void GemmaChatTemplate::Init(const GemmaTokenizer& tokenizer) {
   HWY_ASSERT(tokenizer.Encode("<start_of_turn>model\n", &sot_model_));
   eot_.reserve(2);
   HWY_ASSERT(tokenizer.Encode("<end_of_turn>\n", &eot_));
+  HWY_ASSERT(tokenizer.Encode("\n", &pali_sep_));
+  vlm_soi_.reserve(2);
+  HWY_ASSERT(tokenizer.Encode("\n\n<start_of_image>", &vlm_soi_));
+  vlm_eoi_.reserve(2);
+  HWY_ASSERT(tokenizer.Encode("<end_of_image>\n\n", &vlm_eoi_));
 }

 std::vector<int> GemmaChatTemplate::Apply(size_t pos,
@@ -145,6 +150,33 @@ std::vector<int> GemmaChatTemplate::Apply(size_t pos,
   return out;
 }

+std::vector<int> GemmaChatTemplate::WrapPali(const std::vector<int>& text_part,
+                                             size_t image_batch_size) const {
+  HWY_ASSERT_M(!pali_sep_.empty(),
+               "GemmaChatTemplate has not been initialized.");
+  std::vector<int> out;
+  out.reserve(image_batch_size + 1 + text_part.size() + pali_sep_.size());
+  out.resize(image_batch_size, 0);
+  out.push_back(BOS_ID);
+  out.insert(out.cend(), text_part.cbegin(), text_part.cend());
+  out.insert(out.cend(), pali_sep_.cbegin(), pali_sep_.cend());
+  return out;
+}
+
+std::vector<int> GemmaChatTemplate::WrapVLM(const std::vector<int>& text_part,
+                                            size_t image_batch_size) const {
+  HWY_ASSERT_M(!vlm_soi_.empty() && !vlm_eoi_.empty(),
+               "GemmaChatTemplate has not been initialized.");
+  std::vector<int> out;
+  out.reserve(
+      text_part.size() + vlm_soi_.size() + image_batch_size + vlm_eoi_.size());
+  out.insert(out.cend(), text_part.cbegin(), text_part.cend());
+  out.insert(out.cend(), vlm_soi_.cbegin(), vlm_soi_.cend());
+  out.insert(out.cend(), image_batch_size, -2);
+  out.insert(out.cend(), vlm_eoi_.cbegin(), vlm_eoi_.cend());
+  return out;
+}
+
 std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
                                  const GemmaChatTemplate& chat_template,
                                  const ModelInfo& info, size_t pos,
@@ -170,33 +202,13 @@ std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
                                  size_t image_batch_size) {
   std::vector<int> text_part;
   HWY_ASSERT(tokenizer.Encode(prompt, &text_part));
-  std::vector<int> tokens;
   switch (info.wrapping) {
-    case PromptWrapping::PALIGEMMA: {
-      std::vector<int> sep;
-      HWY_ASSERT(tokenizer.Encode("\n", &sep));
-      tokens.reserve(image_batch_size + 1 + text_part.size() + sep.size());
-      tokens.resize(image_batch_size, 0);
+    case PromptWrapping::PALIGEMMA:
       HWY_ASSERT(pos == 0);
-      tokens.push_back(BOS_ID);
-      tokens.insert(tokens.cend(), text_part.cbegin(), text_part.cend());
-      tokens.insert(tokens.cend(), sep.cbegin(), sep.cend());
-      return tokens;
-    }
-    case PromptWrapping::GEMMA_VLM: {
-      std::vector<int> soi;
-      soi.reserve(2);
-      HWY_ASSERT(tokenizer.Encode("\n\n<start_of_image>", &soi));
-      std::vector<int> eoi;
-      eoi.reserve(2);
-      HWY_ASSERT(tokenizer.Encode("<end_of_image>\n\n", &eoi));
-      tokens.reserve(text_part.size() + soi.size() + image_batch_size + eoi.size());
-      tokens.insert(tokens.cend(), text_part.cbegin(), text_part.cend());
-      tokens.insert(tokens.cend(), soi.cbegin(), soi.cend());
-      tokens.insert(tokens.cend(), image_batch_size, -2);
-      tokens.insert(tokens.cend(), eoi.cbegin(), eoi.cend());
-      return chat_template.Apply(pos, tokens);
-    }
+      return chat_template.WrapPali(text_part, image_batch_size);
+    case PromptWrapping::GEMMA_VLM:
+      return chat_template.Apply(pos, chat_template.WrapVLM(text_part,
+                                                            image_batch_size));
     default:
       HWY_ASSERT_M(false, "Current variant does not support vision prompt.");
   }

@@ -63,11 +63,18 @@ class GemmaChatTemplate {

   void Init(const GemmaTokenizer& tokenizer);
   std::vector<int> Apply(size_t pos, const std::vector<int>& ids) const;
+  std::vector<int> WrapPali(const std::vector<int>& text_part,
+                            size_t image_batch_size) const;
+  std::vector<int> WrapVLM(const std::vector<int>& text_part,
+                           size_t image_batch_size) const;
+
  private:
   std::vector<int> sot_user_;
   std::vector<int> sot_model_;
   std::vector<int> eot_;
+  std::vector<int> pali_sep_;
+  std::vector<int> vlm_soi_;
+  std::vector<int> vlm_eoi_;
 };

 std::vector<int> WrapAndTokenize(const GemmaTokenizer& tokenizer,
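For reference, a minimal sketch of how the new helpers compose. This is a hypothetical call site, not code from this commit: the include path, function names, and prompt text are assumptions; only the `WrapPali`, `WrapVLM`, and `Apply` signatures come from the diff above.

// Hypothetical call site; include path and setup are assumptions.
#include <vector>

#include "gemma/tokenizer.h"  // GemmaTokenizer, GemmaChatTemplate (assumed path)

// PaliGemma layout: image_batch_size zero placeholders, then BOS, the text
// tokens, and a trailing "\n" separator. PaliGemma prompts carry no
// chat-turn framing, which is why WrapAndTokenize requires pos == 0.
std::vector<int> PaliPrompt(const GemmaTokenizer& tokenizer,
                            const GemmaChatTemplate& chat_template,
                            size_t image_batch_size) {
  std::vector<int> text_part;
  HWY_ASSERT(tokenizer.Encode("Describe the image.", &text_part));
  return chat_template.WrapPali(text_part, image_batch_size);
}

// Gemma 3 VLM layout: text, "\n\n<start_of_image>", image_batch_size
// placeholders (-2), "<end_of_image>\n\n"; Apply() then adds the
// <start_of_turn>user ... <end_of_turn> chat framing around the result.
std::vector<int> VlmPrompt(const GemmaTokenizer& tokenizer,
                           const GemmaChatTemplate& chat_template,
                           size_t pos, size_t image_batch_size) {
  std::vector<int> text_part;
  HWY_ASSERT(tokenizer.Encode("Describe the image.", &text_part));
  return chat_template.Apply(
      pos, chat_template.WrapVLM(text_part, image_batch_size));
}

After this change, `WrapAndTokenize` reduces to exactly this dispatch, so most callers never need to touch the helpers directly.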