add: VAETKI tokenizer implementation

suhyun-hwang 2026-01-14 22:19:06 +09:00
parent ca85717886
commit 56c89a1216
3 changed files with 321 additions and 4 deletions


@@ -7671,6 +7671,46 @@ class VaetkiModel(TextModel):
    _experts: list[dict[str, Tensor]] | None = None

    def set_vocab(self):
        # VAETKI: hybrid tokenizer with SPM-style ▁ space markers + BPE rank-based merges + <0xXX> byte fallback
        # manual token loading because VAETKI doesn't fit standard BPE or SPM vocab loading
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)

        vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))

        tokens: list[str] = []
        toktypes: list[int] = []

        reverse_vocab = {id_: tok for tok, id_ in tokenizer.vocab.items()}
        added_vocab = tokenizer.get_added_vocab()
        added_tokens_decoder = tokenizer.added_tokens_decoder
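        # map each id back to its token string; ids absent from the vocab are
        # filled with [PADi] placeholders so the list length matches vocab_size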
        for i in range(vocab_size):
            if i not in reverse_vocab:
                tokens.append(f"[PAD{i}]")
                toktypes.append(gguf.TokenType.UNUSED)
            else:
                token: str = reverse_vocab[i]
                if token in added_vocab:
                    if not added_tokens_decoder[i].normalized:
                        token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
                    if added_tokens_decoder[i].special or self.does_token_look_special(token):
                        toktypes.append(gguf.TokenType.CONTROL)
                    else:
                        toktypes.append(gguf.TokenType.USER_DEFINED)
                else:
                    toktypes.append(gguf.TokenType.NORMAL)
                tokens.append(token)

        self.gguf_writer.add_tokenizer_model("vaetki")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
        special_vocab.add_to_gguf(self.gguf_writer)
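        # VAETKI encodes spaces as ▁ markers itself, so no automatic space
        # prefix should be added at tokenization time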
        self.gguf_writer.add_add_space_prefix(False)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set rope_parameters for hybrid attention (transformers 5.0 format)


@@ -76,6 +76,7 @@ extern "C" {
        LLAMA_VOCAB_TYPE_UGM    = 4, // T5 tokenizer based on Unigram
        LLAMA_VOCAB_TYPE_RWKV   = 5, // RWKV tokenizer based on greedy tokenization
        LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming
        LLAMA_VOCAB_TYPE_VAETKI = 7, // VAETKI tokenizer based on rank-based BPE with SPM-style space markers
    };

    enum llama_rope_type {


@@ -1519,6 +1519,206 @@ private:
    const llm_tokenizer_plamo2 & tokenizer;
};
//
// VAETKI tokenizer
// Hybrid tokenizer: SPM-style ▁ space markers + BPE rank-based merges + <0xXX> byte fallback
//

struct llm_tokenizer_vaetki : llm_tokenizer {
    llm_tokenizer_vaetki(const llama_vocab & vocab) {
        GGML_ASSERT(vocab.get_type() == LLAMA_VOCAB_TYPE_VAETKI);
    }
};

struct llm_tokenizer_vaetki_session {
    llm_tokenizer_vaetki_session(const llama_vocab & vocab) : vocab(vocab) {}
    void tokenize(const std::string & text, std::vector<llama_token> & output) {
        int final_prev_index = -1;

        // Normalize - replace all spaces with ▁ (U+2581)
        std::string normalized;
        normalized.reserve(text.size() * 3);
        for (size_t i = 0; i < text.size(); ) {
            if (text[i] == ' ') {
                normalized += "\xE2\x96\x81";
                i++;
            } else {
                size_t char_len = unicode_len_utf8(text[i]);
                normalized += text.substr(i, char_len);
                i += char_len;
            }
        }
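        // at this point e.g. "Hello World" has become "Hello▁World"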
        // Split on ▁ boundaries, keeping ▁ with the following text
        // "Hello▁World"  -> ["Hello", "▁World"]
        // "Hello▁▁World" -> ["Hello", "▁▁World"]
        std::vector<std::string> word_collection;
        std::string current_word;
        const std::string escaped_space = "\xE2\x96\x81"; // ▁ (U+2581)

        for (size_t i = 0; i < normalized.size(); ) {
            size_t char_len = unicode_len_utf8(normalized[i]);
            if (char_len == 3 &&
                i + 2 < normalized.size() &&
                (unsigned char) normalized[i]     == 0xE2 &&
                (unsigned char) normalized[i + 1] == 0x96 &&
                (unsigned char) normalized[i + 2] == 0x81) {
                // start a new word at this marker, but only if the current word
                // already contains more than markers, so that consecutive ▁
                // accumulate: "Hello▁▁World" -> ["Hello", "▁▁World"],
                // not ["Hello", "▁", "▁World"]
                if (!current_word.empty() &&
                    (current_word.size() < 3 ||
                     current_word.compare(current_word.size() - 3, 3, escaped_space) != 0)) {
                    word_collection.push_back(current_word);
                    current_word.clear();
                }
                current_word += escaped_space;
                i += 3;
            } else {
                current_word += normalized.substr(i, char_len);
                i += char_len;
            }
        }
        if (!current_word.empty()) {
            word_collection.push_back(current_word);
        }
        symbols_final.clear();

        for (const auto & word : word_collection) {
            work_queue = llm_bigram_bpe::queue();
            symbols.clear();

            int index = 0;
            size_t offset = 0;

            // Check if word exists as a single token (ignore_merges behavior)
            if (vocab.text_to_token(word) != LLAMA_TOKEN_NULL) {
                symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
                offset = word.size();
            }
            while (offset < word.size()) {
                llm_symbol sym;
                size_t char_len = std::min(word.size() - offset, (size_t) unicode_len_utf8(word[offset]));
                sym.text = word.c_str() + offset;
                sym.n = char_len;
                offset += sym.n;
                sym.prev = index - 1;
                sym.next = offset == word.size() ? -1 : index + 1;
                index++;
                symbols.emplace_back(sym);
            }

            for (int i = 1; i < (int) symbols.size(); ++i) {
                add_new_bigram(i - 1, i);
            }

            // build token(s)
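            // merges apply lowest rank first: e.g. (illustrative) "▁World" starts
            // as [▁, W, o, r, l, d]; if ("▁", "W") has the lowest rank it merges
            // first, and new bigrams around "▁W" are then re-queued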
            while (!work_queue.empty()) {
                auto bigram = work_queue.pop_move();

                auto & left_symbol = symbols[bigram.left];
                auto & right_symbol = symbols[bigram.right];

                if (left_symbol.n == 0 || right_symbol.n == 0) {
                    continue;
                }

                std::string left_token = std::string(left_symbol.text, left_symbol.n);
                std::string right_token = std::string(right_symbol.text, right_symbol.n);
                if (left_token + right_token != bigram.text) {
                    continue; // Skip this bigram if it's outdated
                }

                // merge the right sym into the left one
                left_symbol.n += right_symbol.n;
                right_symbol.n = 0;

                // remove the right sym from the chain
                left_symbol.next = right_symbol.next;
                if (right_symbol.next >= 0) {
                    symbols[right_symbol.next].prev = bigram.left;
                }

                add_new_bigram(left_symbol.prev, bigram.left); // left side of current symbol
                add_new_bigram(bigram.left, left_symbol.next); // right side of current symbol
            }
            // add the finished tokens to the final list keeping correct order for next and prev
            for (auto & sym : symbols) {
                if (sym.n > 0) {
                    sym.prev = final_prev_index;
                    sym.next = -1;
                    if (final_prev_index != -1) {
                        symbols_final[final_prev_index].next = symbols_final.size();
                    }
                    symbols_final.emplace_back(sym);
                    final_prev_index = symbols_final.size() - 1;
                }
            }
        }
        symbols = symbols_final;

        if (!symbols.empty()) {
            for (int i = 0; i != -1; i = symbols[i].next) {
                auto & symbol = symbols[i];
                if (symbol.n == 0) {
                    continue;
                }

                const std::string str = std::string(symbol.text, symbol.n);
                const auto token = vocab.text_to_token(str);

                if (token == LLAMA_TOKEN_NULL) {
                    // Byte fallback: use <0xXX> format
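                    // e.g. (illustrative) if "é" (bytes 0xC3 0xA9) has no token of
                    // its own, the byte tokens <0xC3> and <0xA9> are emitted instead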
                    for (auto j = str.begin(); j != str.end(); ++j) {
                        char buf[8];
                        snprintf(buf, sizeof(buf), "<0x%02X>", static_cast<uint8_t>(*j));
                        std::string byte_str(buf);
                        auto token_byte = vocab.text_to_token(byte_str);
                        if (token_byte != LLAMA_TOKEN_NULL) {
                            output.push_back(token_byte);
                        }
                    }
                } else {
                    output.push_back(token);
                }
            }
        }
    }
private:
    void add_new_bigram(int left, int right) {
        if (left == -1 || right == -1) {
            return;
        }

        std::string left_token = std::string(symbols[left].text, symbols[left].n);
        std::string right_token = std::string(symbols[right].text, symbols[right].n);

        int rank_found = vocab.find_bpe_rank(left_token, right_token);
        if (rank_found < 0) {
            return;
        }

        llm_bigram_bpe bigram;

        bigram.left = left;
        bigram.right = right;
        bigram.text = left_token + right_token;
        bigram.size = left_token.size() + right_token.size();
        bigram.rank = rank_found;

        work_queue.push(bigram);
    }

    const llama_vocab & vocab;

    std::vector<llm_symbol> symbols;
    std::vector<llm_symbol> symbols_final;
    llm_bigram_bpe::queue work_queue;
};
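// Usage sketch (illustrative only, not part of this commit; assumes a loaded
// llama_vocab `vocab` whose GGUF declares tokenizer_model == "vaetki"):
//
//   std::vector<llama_token> out;
//   llm_tokenizer_vaetki_session session(vocab);
//   session.tokenize("Hello World", out);
//   // "Hello World" -> "Hello▁World" -> {"Hello", "▁World"} -> per-word BPE
//   // merges -> token ids, with <0xXX> byte fallback for unknown pieces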
//
// impl
//

@@ -1832,6 +2032,39 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
            special_sep_id  = LLAMA_TOKEN_NULL;
            special_pad_id  = 3; // <|plamo:pad|>
            special_mask_id = LLAMA_TOKEN_NULL;
        } else if (tokenizer_model == "vaetki") {
            type = LLAMA_VOCAB_TYPE_VAETKI;

            // read bpe merges and populate bpe ranks (same as gpt2)
            const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str());
            if (merges_keyidx == -1) {
                throw std::runtime_error("cannot find tokenizer merges in model file\n");
            }

            const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
            for (int i = 0; i < n_merges; i++) {
                const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
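                // each merge entry is "left right", e.g. "▁t he" -> ("▁t", "he") with rank i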
                std::string first;
                std::string second;

                const size_t pos = word.find(' ', 1);
                if (pos != std::string::npos) {
                    first  = word.substr(0, pos);
                    second = word.substr(pos + 1);
                }

                bpe_ranks.emplace(std::make_pair(first, second), i);
            }

            // VAETKI default special tokens (will be overridden by model config)
            special_bos_id  = LLAMA_TOKEN_NULL;
            special_eos_id  = LLAMA_TOKEN_NULL;
            special_unk_id  = LLAMA_TOKEN_NULL;
            special_sep_id  = LLAMA_TOKEN_NULL;
            special_pad_id  = LLAMA_TOKEN_NULL;
            special_mask_id = LLAMA_TOKEN_NULL;
        } else {
            throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
        }
@@ -2627,6 +2860,7 @@ std::string llama_vocab::impl::type_name() const {
        case LLAMA_VOCAB_TYPE_UGM:    return "UGM";
        case LLAMA_VOCAB_TYPE_RWKV:   return "RWKV";
        case LLAMA_VOCAB_TYPE_PLAMO2: return "PLaMo2";
        case LLAMA_VOCAB_TYPE_VAETKI: return "VAETKI";
        default:                      return "unknown";
    }
}
@@ -2671,7 +2905,9 @@ uint8_t llama_vocab::impl::token_to_byte(llama_token id) const {
    const auto & token_data = id_to_token.at(id);
    switch (get_type()) {
        case LLAMA_VOCAB_TYPE_SPM:
        case LLAMA_VOCAB_TYPE_UGM:
        case LLAMA_VOCAB_TYPE_VAETKI: {
            // <0xXX> format
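            // e.g. "<0x0A>" -> substr(3, 2) == "0A" -> 0x0A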
            auto buf = token_data.text.substr(3, 2);
            return strtol(buf.c_str(), NULL, 16);
        }
@@ -2713,6 +2949,9 @@ void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) {
        case LLAMA_VOCAB_TYPE_PLAMO2:
            tokenizer = std::make_unique<llm_tokenizer_plamo2>(vocab);
            break;
        case LLAMA_VOCAB_TYPE_VAETKI:
            tokenizer = std::make_unique<llm_tokenizer_vaetki>(vocab);
            break;
        default:
            GGML_ABORT("unsupported vocab type");
    }
@@ -3072,6 +3311,41 @@ std::vector<llama_token> llama_vocab::impl::tokenize(
                    }
                }
            } break;
        case LLAMA_VOCAB_TYPE_VAETKI:
            {
                llm_tokenizer_vaetki_session session(vocab);

                if (add_special && add_bos) {
                    GGML_ASSERT(special_bos_id != LLAMA_TOKEN_NULL);
                    output.push_back(special_bos_id);
                }

                for (const auto & fragment : fragment_buffer) {
                    if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                        std::string text = fragment.raw_text.substr(fragment.offset, fragment.length);
#ifdef PRETOKENIZERDEBUG
                        LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str());
#endif
                        session.tokenize(text, output);
                    } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                        output.push_back(fragment.token);
                    }
                }

                if (add_special && add_bos && output.size() >= 2 && output[1] == special_bos_id) {
                    LLAMA_LOG_WARN(
                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
                        "Are you sure this is what you want?\n", __FUNCTION__);
                }

                if (add_special && add_eos) {
                    GGML_ASSERT(special_eos_id != LLAMA_TOKEN_NULL);
                    output.push_back(special_eos_id);
                }
            } break;
        case LLAMA_VOCAB_TYPE_NONE:
            GGML_ABORT("fatal error");
    }
@@ -3120,7 +3394,8 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t
    switch (get_type()) {
        case LLAMA_VOCAB_TYPE_WPM:
        case LLAMA_VOCAB_TYPE_SPM:
        case LLAMA_VOCAB_TYPE_UGM:
        case LLAMA_VOCAB_TYPE_VAETKI: {
            // NOTE: we accept all unsupported token types,
            // suppressing them like CONTROL tokens.
            if (attr & (attr_special | LLAMA_TOKEN_ATTR_USER_DEFINED)) {
@@ -3421,8 +3696,9 @@ llama_token llama_vocab::byte_to_token(uint8_t ch) const {
        case LLAMA_VOCAB_TYPE_BPE: {
            return pimpl->token_to_id.at(unicode_byte_to_utf8(ch));
        }
        case LLAMA_VOCAB_TYPE_PLAMO2:
        case LLAMA_VOCAB_TYPE_VAETKI: {
            // PLaMo-2 and VAETKI use byte tokens in the form <0xXX>
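            // e.g. ch == 0x0A -> "<0x0A>"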
            char hex_str[8];
            snprintf(hex_str, sizeof(hex_str), "<0x%02X>", ch);
            return pimpl->token_to_id.at(hex_str);