diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 58b885c9cf..7a05491868 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7550,6 +7550,7 @@ struct llm_build_bert : public llm_graph_context { } }; +template struct llm_build_modern_bert : public llm_graph_context { llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -18357,7 +18358,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_MODERN_BERT: { - llm = std::make_unique(*this, params); + llm = std::make_unique>(*this, params); } break; case LLM_ARCH_NEO_BERT: { diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 3268387908..00fbe4db1d 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -2487,6 +2487,13 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { for (const auto * token : {"", "", "<|endoftext|>"}) { _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false); } + } else if ( _contains_any(model_name, {"modern-bert"})) { + if ( token_to_id.count("MASK") == 0 ) { + LLAMA_LOG_WARN("%s: Mask token missing in vocab!\n", __func__); + } + else { + _set_token_attr("[MASK]", LLAMA_TOKEN_ATTR_LSTRIP, true); + } } } }