added mask check in vocab

This commit is contained in:
ryan-mangeno 2025-09-12 11:45:02 -04:00
parent 20d448a8d7
commit db4f5656e4
2 changed files with 9 additions and 1 deletions

View File

@ -7550,6 +7550,7 @@ struct llm_build_bert : public llm_graph_context {
} }
}; };
template <bool iswa>
struct llm_build_modern_bert : public llm_graph_context { struct llm_build_modern_bert : public llm_graph_context {
llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
@ -18357,7 +18358,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
} break; } break;
case LLM_ARCH_MODERN_BERT: case LLM_ARCH_MODERN_BERT:
{ {
llm = std::make_unique<llm_build_modern_bert>(*this, params); llm = std::make_unique<llm_build_modern_bert<true>>(*this, params);
} break; } break;
case LLM_ARCH_NEO_BERT: case LLM_ARCH_NEO_BERT:
{ {

View File

@ -2487,6 +2487,13 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
for (const auto * token : {"<unk>", "<s>", "<|endoftext|>"}) { for (const auto * token : {"<unk>", "<s>", "<|endoftext|>"}) {
_set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false); _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
} }
} else if ( _contains_any(model_name, {"modern-bert"})) {
if ( token_to_id.count("MASK") == 0 ) {
LLAMA_LOG_WARN("%s: Mask token missing in vocab!\n", __func__);
}
else {
_set_token_attr("[MASK]", LLAMA_TOKEN_ATTR_LSTRIP, true);
}
} }
} }
} }