add: VAETKI tokenizer implementation

parent ca85717886
commit 56c89a1216
@@ -7671,6 +7671,46 @@ class VaetkiModel(TextModel):
     _experts: list[dict[str, Tensor]] | None = None
 
+    def set_vocab(self):
+        # VAETKI: hybrid tokenizer with SPM-style ▁ space markers + BPE rank-based merges + <0xXX> byte fallback
+        # manual token loading because VAETKI doesn't fit standard BPE or SPM vocab loading
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+        vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
+
+        tokens: list[str] = []
+        toktypes: list[int] = []
+
+        reverse_vocab = {id_: tok for tok, id_ in tokenizer.vocab.items()}
+        added_vocab = tokenizer.get_added_vocab()
+        added_tokens_decoder = tokenizer.added_tokens_decoder
+
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                tokens.append(f"[PAD{i}]")
+                toktypes.append(gguf.TokenType.UNUSED)
+            else:
+                token: str = reverse_vocab[i]
+                if token in added_vocab:
+                    if not added_tokens_decoder[i].normalized:
+                        token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
+                    if added_tokens_decoder[i].special or self.does_token_look_special(token):
+                        toktypes.append(gguf.TokenType.CONTROL)
+                    else:
+                        toktypes.append(gguf.TokenType.USER_DEFINED)
+                else:
+                    toktypes.append(gguf.TokenType.NORMAL)
+                tokens.append(token)
+
+        self.gguf_writer.add_tokenizer_model("vaetki")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+        self.gguf_writer.add_add_space_prefix(False)
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         # Set rope_parameters for hybrid attention (transformers 5.0 format)
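Note: the behaviors named in the comment above can be spot-checked straight from the Hugging Face tokenizer before conversion. A minimal sketch, assuming `dir_model` points at a VAETKI checkpoint (the printed pieces are illustrative, not taken from this commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(dir_model)  # dir_model: hypothetical checkpoint path

# SPM-style markers: spaces surface as U+2581 ("▁") glued to the following piece
print(tokenizer.tokenize("Hello World"))   # e.g. ['Hello', '▁World']

# byte fallback: a character with no piece of its own decomposes into <0xXX> tokens
print(tokenizer.tokenize("\x07"))          # e.g. ['<0x07>']

# added tokens are what the loop above classifies as CONTROL or USER_DEFINED
print(list(tokenizer.get_added_vocab())[:5])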
@@ -76,6 +76,7 @@ extern "C" {
         LLAMA_VOCAB_TYPE_UGM    = 4, // T5 tokenizer based on Unigram
         LLAMA_VOCAB_TYPE_RWKV   = 5, // RWKV tokenizer based on greedy tokenization
         LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming
+        LLAMA_VOCAB_TYPE_VAETKI = 7, // VAETKI tokenizer based on rank-based BPE with SPM-style space markers
     };
 
     enum llama_rope_type {
@@ -1519,6 +1519,206 @@ private:
     const llm_tokenizer_plamo2 & tokenizer;
 };
 
+//
+// VAETKI tokenizer
+// Hybrid tokenizer: SPM-style ▁ space markers + BPE rank-based merges + <0xXX> byte fallback
+//
+
+struct llm_tokenizer_vaetki : llm_tokenizer {
+    llm_tokenizer_vaetki(const llama_vocab & vocab) {
+        GGML_ASSERT(vocab.get_type() == LLAMA_VOCAB_TYPE_VAETKI);
+    }
+};
+
+struct llm_tokenizer_vaetki_session {
+    llm_tokenizer_vaetki_session(const llama_vocab & vocab)
+        : vocab(vocab) {}
+
+    void tokenize(const std::string & text, std::vector<llama_token> & output) {
+        int final_prev_index = -1;
+
+        // Normalize - replace all spaces with ▁ (U+2581)
+        std::string normalized;
+        normalized.reserve(text.size() * 3);
+        for (size_t i = 0; i < text.size(); ) {
+            if (text[i] == ' ') {
+                normalized += "\xE2\x96\x81";
+                i++;
+            } else {
+                size_t char_len = unicode_len_utf8(text[i]);
+                normalized += text.substr(i, char_len);
+                i += char_len;
+            }
+        }
+
+        // Split on ▁ boundaries, with each ▁ starting a new word
+        // "Hello▁World"  -> ["Hello", "▁World"]
+        // "Hello▁▁World" -> ["Hello", "▁", "▁World"]
+        std::vector<std::string> word_collection;
+        std::string current_word;
+        const std::string escaped_space = "\xE2\x96\x81"; // ▁ (U+2581)
+
+        for (size_t i = 0; i < normalized.size(); ) {
+            size_t char_len = unicode_len_utf8(normalized[i]);
+
+            if (char_len == 3 &&
+                i + 2 < normalized.size() &&
+                (unsigned char)normalized[i]   == 0xE2 &&
+                (unsigned char)normalized[i+1] == 0x96 &&
+                (unsigned char)normalized[i+2] == 0x81) {
+                if (!current_word.empty()) {
+                    word_collection.push_back(current_word);
+                    current_word.clear();
+                }
+                current_word = escaped_space;
+                i += 3;
+            } else {
+                current_word += normalized.substr(i, char_len);
+                i += char_len;
+            }
+        }
+        if (!current_word.empty()) {
+            word_collection.push_back(current_word);
+        }
+
+        symbols_final.clear();
+
+        for (const auto & word : word_collection) {
+            work_queue = llm_bigram_bpe::queue();
+            symbols.clear();
+
+            int index = 0;
+            size_t offset = 0;
+
+            // Check if word exists as a single token (ignore_merges behavior)
+            if (vocab.text_to_token(word) != LLAMA_TOKEN_NULL) {
+                symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
+                offset = word.size();
+            }
+
+            while (offset < word.size()) {
+                llm_symbol sym;
+                size_t char_len = std::min(word.size() - offset, (size_t) unicode_len_utf8(word[offset]));
+                sym.text = word.c_str() + offset;
+                sym.n = char_len;
+                offset += sym.n;
+                sym.prev = index - 1;
+                sym.next = offset == word.size() ? -1 : index + 1;
+                index++;
+                symbols.emplace_back(sym);
+            }
+            for (int i = 1; i < (int) symbols.size(); ++i) {
+                add_new_bigram(i - 1, i);
+            }
+
+            // build token(s)
+            while (!work_queue.empty()) {
+                auto bigram = work_queue.pop_move();
+
+                auto & left_symbol  = symbols[bigram.left];
+                auto & right_symbol = symbols[bigram.right];
+
+                if (left_symbol.n == 0 || right_symbol.n == 0) {
+                    continue;
+                }
+                std::string left_token  = std::string(left_symbol.text,  left_symbol.n);
+                std::string right_token = std::string(right_symbol.text, right_symbol.n);
+                if (left_token + right_token != bigram.text) {
+                    continue; // Skip this bigram if it's outdated
+                }
+
+                // merge the right sym into the left one
+                left_symbol.n += right_symbol.n;
+                right_symbol.n = 0;
+
+                // remove the right sym from the chain
+                left_symbol.next = right_symbol.next;
+                if (right_symbol.next >= 0) {
+                    symbols[right_symbol.next].prev = bigram.left;
+                }
+
+                add_new_bigram(left_symbol.prev, bigram.left); // left side of current symbol
+                add_new_bigram(bigram.left, left_symbol.next); // right side of current symbol
+            }
+
+            // add the finished tokens to the final list keeping correct order for next and prev
+            for (auto & sym : symbols) {
+                if (sym.n > 0) {
+                    sym.prev = final_prev_index;
+                    sym.next = -1;
+                    if (final_prev_index != -1) {
+                        symbols_final[final_prev_index].next = symbols_final.size();
+                    }
+                    symbols_final.emplace_back(sym);
+                    final_prev_index = symbols_final.size() - 1;
+                }
+            }
+        }
+
+        symbols = symbols_final;
+
+        if (!symbols.empty()) {
+            for (int i = 0; i != -1; i = symbols[i].next) {
+                auto & symbol = symbols[i];
+                if (symbol.n == 0) {
+                    continue;
+                }
+
+                const std::string str = std::string(symbol.text, symbol.n);
+                const auto token = vocab.text_to_token(str);
+
+                if (token == LLAMA_TOKEN_NULL) {
+                    // Byte fallback: use <0xXX> format
+                    for (auto j = str.begin(); j != str.end(); ++j) {
+                        char buf[8];
+                        snprintf(buf, sizeof(buf), "<0x%02X>", static_cast<uint8_t>(*j));
+                        std::string byte_str(buf);
+                        auto token_byte = vocab.text_to_token(byte_str);
+                        if (token_byte != LLAMA_TOKEN_NULL) {
+                            output.push_back(token_byte);
+                        }
+                    }
+                } else {
+                    output.push_back(token);
+                }
+            }
+        }
+    }
+
+private:
+    void add_new_bigram(int left, int right) {
+        if (left == -1 || right == -1) {
+            return;
+        }
+        std::string left_token  = std::string(symbols[left].text,  symbols[left].n);
+        std::string right_token = std::string(symbols[right].text, symbols[right].n);
+
+        int rank_found = -1;
+
+        rank_found = vocab.find_bpe_rank(left_token, right_token);
+
+        if (rank_found < 0) {
+            return;
+        }
+
+        llm_bigram_bpe bigram;
+
+        bigram.left  = left;
+        bigram.right = right;
+        bigram.text  = left_token + right_token;
+        bigram.size  = left_token.size() + right_token.size();
+        bigram.rank  = rank_found;
+
+        work_queue.push(bigram);
+    }
+
+    const llama_vocab & vocab;
+
+    std::vector<llm_symbol> symbols;
+    std::vector<llm_symbol> symbols_final;
+    llm_bigram_bpe::queue   work_queue;
+};
+
 //
 // impl
 //
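Note: stripped of the linked-list and priority-queue machinery, the pipeline above reduces to four steps: normalize spaces to ▁, split so each ▁ starts a new word, merge adjacent pairs by ascending BPE rank, and fall back to <0xXX> byte tokens. A simplified Python rendering under those assumptions — `vocab` (piece -> id) and `ranks` ((left, right) -> rank) stand in for the GGUF-loaded tables, and an iterative lowest-rank scan replaces the work queue but yields the same greedy merge order:

def vaetki_tokenize(text: str, vocab: dict[str, int], ranks: dict[tuple[str, str], int]) -> list[int]:
    # 1. normalize: every space becomes U+2581
    normalized = text.replace(" ", "\u2581")

    # 2. split: each ▁ starts a new word ("Hello▁▁World" -> ["Hello", "▁", "▁World"])
    words: list[str] = []
    for ch in normalized:
        if ch == "\u2581" or not words:
            words.append(ch)
        else:
            words[-1] += ch

    output: list[int] = []
    for word in words:
        # 3. ignore_merges: a word that is already a single piece is emitted directly
        if word in vocab:
            output.append(vocab[word])
            continue

        # 4. rank-based BPE: repeatedly merge the lowest-ranked adjacent pair
        parts = list(word)
        while len(parts) > 1:
            pairs = [(ranks[(a, b)], i) for i, (a, b) in enumerate(zip(parts, parts[1:])) if (a, b) in ranks]
            if not pairs:
                break
            _, i = min(pairs)
            parts[i:i + 2] = [parts[i] + parts[i + 1]]

        # 5. byte fallback for anything still missing from the vocab
        for part in parts:
            if part in vocab:
                output.append(vocab[part])
            else:
                for byte in part.encode("utf-8"):
                    byte_id = vocab.get(f"<0x{byte:02X}>")
                    if byte_id is not None:
                        output.append(byte_id)
    return output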
@@ -1832,6 +2032,39 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
         special_sep_id  = LLAMA_TOKEN_NULL;
         special_pad_id  = 3; // <|plamo:pad|>
         special_mask_id = LLAMA_TOKEN_NULL;
+    } else if (tokenizer_model == "vaetki") {
+        type = LLAMA_VOCAB_TYPE_VAETKI;
+
+        // read bpe merges and populate bpe ranks (same as gpt2)
+        const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
+        if (merges_keyidx == -1) {
+            throw std::runtime_error("cannot find tokenizer merges in model file\n");
+        }
+
+        const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
+        for (int i = 0; i < n_merges; i++) {
+            const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
+
+            std::string first;
+            std::string second;
+
+            const size_t pos = word.find(' ', 1);
+
+            if (pos != std::string::npos) {
+                first  = word.substr(0, pos);
+                second = word.substr(pos + 1);
+            }
+
+            bpe_ranks.emplace(std::make_pair(first, second), i);
+        }
+
+        // VAETKI default special tokens (will be overridden by model config)
+        special_bos_id  = LLAMA_TOKEN_NULL;
+        special_eos_id  = LLAMA_TOKEN_NULL;
+        special_unk_id  = LLAMA_TOKEN_NULL;
+        special_sep_id  = LLAMA_TOKEN_NULL;
+        special_pad_id  = LLAMA_TOKEN_NULL;
+        special_mask_id = LLAMA_TOKEN_NULL;
     } else {
         throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
     }
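Note: a small sketch of the merge parsing above, with made-up merge lines. Starting the search at index 1 (word.find(' ', 1)) means a first piece that is itself a single space still splits after that space rather than at position 0:

merges = ["h e", "he l", "hel l", "hell o"]  # illustrative only, not real VAETKI merges

bpe_ranks: dict[tuple[str, str], int] = {}
for rank, line in enumerate(merges):
    pos = line.find(" ", 1)          # mirrors word.find(' ', 1) above
    if pos != -1:
        bpe_ranks[(line[:pos], line[pos + 1:])] = rank

assert bpe_ranks[("he", "l")] == 1   # lower rank = merged earlier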
@@ -2627,6 +2860,7 @@ std::string llama_vocab::impl::type_name() const{
         case LLAMA_VOCAB_TYPE_UGM:    return "UGM";
         case LLAMA_VOCAB_TYPE_RWKV:   return "RWKV";
         case LLAMA_VOCAB_TYPE_PLAMO2: return "PLaMo2";
+        case LLAMA_VOCAB_TYPE_VAETKI: return "VAETKI";
         default:                      return "unknown";
     }
 }
@@ -2671,7 +2905,9 @@ uint8_t llama_vocab::impl::token_to_byte(llama_token id) const {
     const auto & token_data = id_to_token.at(id);
     switch (get_type()) {
         case LLAMA_VOCAB_TYPE_SPM:
-        case LLAMA_VOCAB_TYPE_UGM: {
+        case LLAMA_VOCAB_TYPE_UGM:
+        case LLAMA_VOCAB_TYPE_VAETKI: {
+            // <0xXX> format
             auto buf = token_data.text.substr(3, 2);
             return strtol(buf.c_str(), NULL, 16);
         }
@@ -2713,6 +2949,9 @@ void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) {
         case LLAMA_VOCAB_TYPE_PLAMO2:
            tokenizer = std::make_unique<llm_tokenizer_plamo2>(vocab);
            break;
+        case LLAMA_VOCAB_TYPE_VAETKI:
+            tokenizer = std::make_unique<llm_tokenizer_vaetki>(vocab);
+            break;
         default:
             GGML_ABORT("unsupported vocab type");
     }
@@ -3072,6 +3311,41 @@ std::vector<llama_token> llama_vocab::impl::tokenize(
                     }
                 }
             } break;
+        case LLAMA_VOCAB_TYPE_VAETKI:
+            {
+                llm_tokenizer_vaetki_session session(vocab);
+
+                if (add_special && add_bos) {
+                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_bos_id);
+                }
+
+                for (const auto & fragment : fragment_buffer) {
+                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
+                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
+
+#ifdef PRETOKENIZERDEBUG
+                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
+#endif
+
+                        session.tokenize(text, output);
+                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
+                        output.push_back(fragment.token);
+                    }
+                }
+
+                if (add_special && add_bos && output.size() >= 2 && output[1] == special_bos_id) {
+                    LLAMA_LOG_WARN(
+                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                        "Are you sure this is what you want?\n", __FUNCTION__);
+                }
+
+                if (add_special && add_eos) {
+                    GGML_ASSERT(special_eos_id != LLAMA_TOKEN_NULL);
+                    output.push_back(special_eos_id);
+                }
+            } break;
         case LLAMA_VOCAB_TYPE_NONE:
             GGML_ABORT("fatal error");
     }
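Note: the double-BOS warning is purely positional: output[0] is the automatically added BOS, so a BOS the user already spelled out in the prompt lands at output[1]. A toy illustration with made-up token ids:

special_bos_id = 1                       # hypothetical id
output = [special_bos_id]                # added because add_special && add_bos
output += [special_bos_id, 15496, 2159]  # prompt itself began with a BOS token

if len(output) >= 2 and output[1] == special_bos_id:
    print("prompt now starts with 2 BOS tokens - probably unintended")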
@@ -3120,7 +3394,8 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t
     switch (get_type()) {
         case LLAMA_VOCAB_TYPE_WPM:
         case LLAMA_VOCAB_TYPE_SPM:
-        case LLAMA_VOCAB_TYPE_UGM: {
+        case LLAMA_VOCAB_TYPE_UGM:
+        case LLAMA_VOCAB_TYPE_VAETKI: {
             // NOTE: we accept all unsupported token types,
             // suppressing them like CONTROL tokens.
             if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
@@ -3421,8 +3696,9 @@ llama_token llama_vocab::byte_to_token(uint8_t ch) const {
         case LLAMA_VOCAB_TYPE_BPE: {
             return pimpl->token_to_id.at(unicode_byte_to_utf8(ch));
         }
-        case LLAMA_VOCAB_TYPE_PLAMO2: {
-            // PLaMo-2 uses byte tokens in format <0xXX>
+        case LLAMA_VOCAB_TYPE_PLAMO2:
+        case LLAMA_VOCAB_TYPE_VAETKI: {
+            // PLaMo-2/VAETKI use byte tokens in format <0xXX>
             char hex_str[8];
             snprintf(hex_str, sizeof(hex_str), "<0x%02X>", ch);
             return pimpl->token_to_id.at(hex_str);
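Note: both directions of the byte mapping reduce to the same <0xXX> spelling. A two-line sketch of the round trip performed by byte_to_token here and token_to_byte earlier:

ch = 0x41
piece = f"<0x{ch:02X}>"            # byte_to_token: format the lookup key -> "<0x41>"
assert int(piece[3:5], 16) == ch   # token_to_byte: substr(3, 2) parsed as hex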