vocab: fix Gemma4 tokenizer (#21343)
* seems to work * fix case with new line Co-authored-by: sayap <sokann@gmail.com> * gemma 4: fix pre tok regex --------- Co-authored-by: Xuan Son Nguyen <son@huggingface.co> Co-authored-by: sayap <sokann@gmail.com>
This commit is contained in:
parent
0c58ba3365
commit
b069b10ab4
|
|
@ -7464,9 +7464,6 @@ class Gemma4Model(Gemma3Model):
|
|||
|
||||
assert len(tokens) == vocab.vocab_size
|
||||
|
||||
# TODO @ngxson : there are some known (rare) issues with the tokenizer during development
|
||||
# but I don't have time to dive into them right now;
|
||||
# using a dedicated tokenizer name so that we can fix later without re-converting GGUF
|
||||
self.gguf_writer.add_tokenizer_model("gemma4")
|
||||
self.gguf_writer.add_token_list(tokens)
|
||||
self.gguf_writer.add_token_scores(scores)
|
||||
|
|
|
|||
|
|
@ -493,6 +493,16 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
|||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_GEMMA4:
|
||||
// Gemma4 uses SPM-style BPE: spaces are replaced with ▁ by the
|
||||
// normalizer, then BPE merges run on the whole text without
|
||||
// word-level pre-splitting. We only need to split on newlines
|
||||
// since BPE merge lookup asserts no newlines in tokens.
|
||||
regex_exprs = {
|
||||
"[^\\n]+|[\\n]+",
|
||||
};
|
||||
byte_encode = false; // uses raw UTF-8, not GPT-2 byte encoding
|
||||
break;
|
||||
default:
|
||||
// default regex for BPE tokenization pre-processing
|
||||
regex_exprs = {
|
||||
|
|
@ -506,6 +516,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
|||
}
|
||||
|
||||
std::vector<std::string> regex_exprs;
|
||||
bool byte_encode = true; // GPT-2 byte encoding; false for SPM-style BPE (raw UTF-8)
|
||||
};
|
||||
|
||||
struct llm_tokenizer_bpe_session {
|
||||
|
|
@ -550,9 +561,10 @@ struct llm_tokenizer_bpe_session {
|
|||
|
||||
void tokenize(const std::string & text, std::vector<llama_token> & output) {
|
||||
int final_prev_index = -1;
|
||||
const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs);
|
||||
const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs, tokenizer.byte_encode);
|
||||
|
||||
symbols_final.clear();
|
||||
auto tok_pre = vocab.get_pre_type();
|
||||
|
||||
for (const auto & word : word_collection) {
|
||||
work_queue = llm_bigram_bpe::queue();
|
||||
|
|
@ -565,6 +577,13 @@ struct llm_tokenizer_bpe_session {
|
|||
if (vocab.get_ignore_merges() && vocab.text_to_token(word) != LLAMA_TOKEN_NULL) {
|
||||
symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
|
||||
offset = word.size();
|
||||
} else if (tok_pre == LLAMA_VOCAB_PRE_TYPE_GEMMA4 && word.find_first_not_of('\n') == std::string::npos) {
|
||||
// fix for gemma 4, ref: https://github.com/ggml-org/llama.cpp/pull/21343
|
||||
auto tok = vocab.text_to_token(word);
|
||||
if (tok != LLAMA_TOKEN_NULL) {
|
||||
symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
|
||||
offset = word.size();
|
||||
}
|
||||
}
|
||||
|
||||
while (offset < word.size()) {
|
||||
|
|
@ -1864,7 +1883,31 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||
special_pad_id = 3; // <|plamo:pad|>
|
||||
special_mask_id = LLAMA_TOKEN_NULL;
|
||||
} else if (tokenizer_model == "gemma4") {
|
||||
type = LLAMA_VOCAB_TYPE_SPM;
|
||||
type = LLAMA_VOCAB_TYPE_BPE;
|
||||
|
||||
// read bpe merges and populate bpe ranks
|
||||
const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
|
||||
if (merges_keyidx == -1) {
|
||||
throw std::runtime_error("cannot find tokenizer merges in model file\n");
|
||||
}
|
||||
{
|
||||
const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
|
||||
for (int i = 0; i < n_merges; i++) {
|
||||
const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
|
||||
|
||||
std::string first;
|
||||
std::string second;
|
||||
|
||||
const size_t pos = word.find(' ', 1);
|
||||
|
||||
if (pos != std::string::npos) {
|
||||
first = word.substr(0, pos);
|
||||
second = word.substr(pos + 1);
|
||||
}
|
||||
|
||||
bpe_ranks.emplace(std::make_pair(first, second), i);
|
||||
}
|
||||
}
|
||||
|
||||
// default special tokens (to be read from GGUF)
|
||||
special_bos_id = LLAMA_TOKEN_NULL;
|
||||
|
|
@ -1874,7 +1917,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||
special_pad_id = LLAMA_TOKEN_NULL;
|
||||
special_mask_id = LLAMA_TOKEN_NULL;
|
||||
|
||||
tokenizer_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||
tokenizer_pre = "gemma4";
|
||||
} else {
|
||||
throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
|
||||
}
|
||||
|
|
@ -1882,6 +1925,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||
// for now, only BPE models have pre-tokenizers
|
||||
if (type == LLAMA_VOCAB_TYPE_BPE) {
|
||||
add_space_prefix = false;
|
||||
escape_whitespaces = false;
|
||||
clean_spaces = true;
|
||||
if (tokenizer_pre.empty()) {
|
||||
LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
|
||||
|
|
@ -1948,6 +1992,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
|||
} else if (
|
||||
tokenizer_pre == "jais-2") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2;
|
||||
} else if (
|
||||
tokenizer_pre == "gemma4") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_GEMMA4;
|
||||
escape_whitespaces = true;
|
||||
} else if (
|
||||
tokenizer_pre == "jina-v1-en" ||
|
||||
tokenizer_pre == "jina-v2-code" ||
|
||||
|
|
@ -3045,6 +3093,10 @@ std::vector<llama_token> llama_vocab::impl::tokenize(
|
|||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||
std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
||||
|
||||
if (escape_whitespaces) {
|
||||
llama_escape_whitespace(text);
|
||||
}
|
||||
|
||||
#ifdef PRETOKENIZERDEBUG
|
||||
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
|
||||
#endif
|
||||
|
|
@ -3224,6 +3276,12 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t
|
|||
return _try_copy(token_text.data(), token_text.size());
|
||||
}
|
||||
if (attr & LLAMA_TOKEN_ATTR_NORMAL) {
|
||||
if (escape_whitespaces) {
|
||||
// SPM-style BPE: tokens contain ▁ for spaces
|
||||
std::string result = token_text;
|
||||
llama_unescape_whitespace(result);
|
||||
return _try_copy(result.data(), result.size());
|
||||
}
|
||||
std::string result = llama_decode_text(token_text);
|
||||
return _try_copy(result.data(), result.size());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,6 +58,7 @@ enum llama_vocab_pre_type {
|
|||
LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47,
|
||||
LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48,
|
||||
LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49,
|
||||
LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50,
|
||||
};
|
||||
|
||||
struct LLM_KV;
|
||||
|
|
|
|||
|
|
@ -912,7 +912,7 @@ bool unicode_cpt_is_han(uint32_t cpt) {
|
|||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
|
||||
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs, bool byte_encode) {
|
||||
// unicode categories
|
||||
static const std::map<std::string, int> k_ucat_enum = {
|
||||
{ "\\p{N}", unicode_cpt_flags::NUMBER },
|
||||
|
|
@ -1099,5 +1099,9 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
|||
start += offset;
|
||||
}
|
||||
|
||||
return unicode_byte_encoding_process(bpe_words);
|
||||
if (byte_encode) {
|
||||
return unicode_byte_encoding_process(bpe_words);
|
||||
}
|
||||
|
||||
return bpe_words;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -108,4 +108,4 @@ uint32_t unicode_tolower(uint32_t cpt);
|
|||
|
||||
bool unicode_cpt_is_han(uint32_t cpt);
|
||||
|
||||
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);
|
||||
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs, bool byte_encode = true);
|
||||
|
|
|
|||
Loading…
Reference in New Issue