working on ModernBERT support, now working on building the graph

This commit is contained in:
ryan-mangeno 2025-08-25 16:15:40 -04:00
parent 6643c5a852
commit ac67fc6887
6 changed files with 265 additions and 13 deletions

View File

@@ -18,6 +18,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_STARCODER,      "starcoder"      },
    { LLM_ARCH_REFACT,         "refact"         },
    { LLM_ARCH_BERT,           "bert"           },
    { LLM_ARCH_MODERN_BERT,    "modern-bert"    },
    { LLM_ARCH_NOMIC_BERT,     "nomic-bert"     },
    { LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
    { LLM_ARCH_NEO_BERT,       "neo-bert"       },
@@ -505,6 +506,21 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
        { LLM_TENSOR_CLS_OUT, "cls.output" },
    },
},
{
    LLM_ARCH_MODERN_BERT,
    {
        { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
        { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
        { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
        { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
        { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
        { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
        { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
        { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
        { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
        { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
    },
},
{
    LLM_ARCH_NOMIC_BERT,
    {
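
Editor's note (not part of the commit): the per-layer entries above are printf-style templates; llama.cpp's LLM_TN helper expands them with the block index and a suffix. A minimal sketch of that expansion, not the actual implementation:

    // hedged sketch of the name expansion for one LLM_TENSOR_NAMES entry
    char name[GGML_MAX_NAME];
    snprintf(name, sizeof(name), "blk.%d.attn_qkv", /*block index*/ 0); // -> "blk.0.attn_qkv"
    // create_tensor() then appends the ".weight" or ".bias" suffix

so the ModernBERT map resolves to concrete GGUF names such as "blk.0.attn_qkv.weight".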

View File

@@ -1375,7 +1375,9 @@ ggml_tensor * llm_graph_context::build_attn(
    // [TAG_NO_CACHE_PAD]
    // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
    LLAMA_LOG_INFO("ubatch.equal_seqs() = %d, n_seqs = %d\n", ubatch.equal_seqs(), ubatch.n_seqs);
    //assert(!ubatch.equal_seqs());

    ggml_tensor * q = q_cur;
    ggml_tensor * k = k_cur;

View File

@@ -451,6 +451,7 @@ void llama_model::load_arch(llama_model_loader & ml) {
}

void llama_model::load_hparams(llama_model_loader & ml) {
    const gguf_context * ctx = ml.meta.get();

    // get metadata as string
@@ -464,6 +465,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
        gguf_kv.emplace(name, value);
    }

    // get general kv
    ml.get_key(LLM_KV_GENERAL_NAME, name, false);
@@ -584,6 +586,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
    }

    // arch-specific KVs
    LLAMA_LOG_INFO("Switching Arch\n");
    switch (arch) {
        case LLM_ARCH_LLAMA:
            {
@@ -757,6 +760,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                    default: type = LLM_TYPE_UNKNOWN;
                }
            } break;
        case LLM_ARCH_MODERN_BERT:
            {
                //ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
                LLAMA_LOG_INFO("Switching Modern Bert Arch\n");
                switch (hparams.n_layer) {
                    case 12:
                        type = LLM_TYPE_47M; break; // granite-embedding-small
                    default: type = LLM_TYPE_UNKNOWN;
                }
            } break;
        case LLM_ARCH_JINA_BERT_V2:
            {
                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -1888,7 +1901,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
void llama_model::load_vocab(llama_model_loader & ml) {
    const auto kv = LLM_KV(arch);

    vocab.load(ml, kv);
}
@@ -2022,6 +2034,22 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
    const int64_t n_expert_used = hparams.n_expert_used;
    const int64_t n_ctx_train   = hparams.n_ctx_train;
LLAMA_LOG_INFO("n_head = %lld\n", (long long) n_head);
LLAMA_LOG_INFO("n_head_kv = %lld\n", (long long) n_head_kv);
LLAMA_LOG_INFO("n_embd = %lld\n", (long long) n_embd);
LLAMA_LOG_INFO("n_embd_k_gqa = %lld\n", (long long) n_embd_k_gqa);
LLAMA_LOG_INFO("n_embd_v_gqa = %lld\n", (long long) n_embd_v_gqa);
LLAMA_LOG_INFO("n_embd_head_k = %lld\n", (long long) n_embd_head_k);
LLAMA_LOG_INFO("n_embd_head_v = %lld\n", (long long) n_embd_head_v);
LLAMA_LOG_INFO("n_ff = %lld\n", (long long) n_ff);
LLAMA_LOG_INFO("n_embd_gqa = %lld\n", (long long) n_embd_gqa);
LLAMA_LOG_INFO("n_vocab = %lld\n", (long long) n_vocab);
LLAMA_LOG_INFO("n_token_types = %lld\n", (long long) n_token_types);
LLAMA_LOG_INFO("n_rot = %lld\n", (long long) n_rot);
LLAMA_LOG_INFO("n_expert = %lld\n", (long long) n_expert);
LLAMA_LOG_INFO("n_expert_used = %lld\n", (long long) n_expert_used);
LLAMA_LOG_INFO("n_ctx_train = %lld\n", (long long) n_ctx_train);
    if (n_expert > 0 && hparams.n_expert_used == 0) {
        throw std::runtime_error("model has expert layers but no expert layers are used");
    }
@@ -2033,7 +2061,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
    auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags) -> ggml_tensor * {
        ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str());
        LLAMA_LOG_INFO("Creating Tensor: %s\n", tn.str().c_str());

        if (!t_meta) {
            if (flags & TENSOR_NOT_REQUIRED) {
                return nullptr;
@@ -2108,7 +2136,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
        }

        ggml_backend_buffer_type_t buft = nullptr;

        // check overrides
        if (ml.tensor_buft_overrides) {
            std::string tensor_name = tn.str();
@@ -2156,7 +2183,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                first_moved_to_buft = buft;
            }
        }

        ggml_context * ctx = ctx_for_buft(buft);

        // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one
@@ -2614,11 +2640,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
            case LLM_ARCH_NOMIC_BERT_MOE:
                {
                    tok_embd   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
                    tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
                    type_embd  = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED);
                    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0);

                    if (arch == LLM_ARCH_BERT) {
                        pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0);

                        cls   = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
                        cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"),   {n_embd},         TENSOR_NOT_REQUIRED);
@@ -2626,14 +2655,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                        cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
                    }

                    for (int i = 0; i < n_layer; ++i) {
                        auto & layer = layers[i];

                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);

                        if (!layer.wqkv) {
                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
@@ -2648,6 +2674,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
                        layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias",   i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED);

                        layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
                        layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias",   i), {n_embd}, 0);
@@ -2657,6 +2684,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff, n_embd, n_expert}, 0);
                            layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, 0);
                        } else {
                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
@@ -2671,6 +2699,33 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                        layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
                        layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias",   i), {n_embd}, 0);
                    }
                } break;
            case LLM_ARCH_MODERN_BERT:
                {
                    tok_embd    = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
                    tok_norm    = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);

                    for (int i = 0; i < n_layer; ++i) {
                        auto & layer = layers[i];

                        // layer 0 uses the identity, so we don't need norm weights for that layer
                        if (i != 0) {
                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
                        } else {
                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
                        }

                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);

                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_ff, n_embd},     0);
                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_ff * 2}, 0);
                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
                    }
                } break;
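
As a sanity check on the fused-QKV shape above, its second dimension is n_embd (for Q) plus 2*n_embd_gqa (for K and V). A minimal sketch, assuming the usual llama.cpp relation n_embd_gqa = n_head_kv * n_embd_head; the numbers are hypothetical:

    // hypothetical dimensions for a small ModernBERT-style config
    const int64_t n_embd_head = 64;
    const int64_t n_head      = 12;
    const int64_t n_head_kv   = 12;                      // no GQA: n_head_kv == n_head
    const int64_t n_embd      = n_head    * n_embd_head; // 768
    const int64_t n_embd_gqa  = n_head_kv * n_embd_head; // 768
    const int64_t qkv_width   = n_embd + 2*n_embd_gqa;   // 2304 = width of blk.%d.attn_qkv.weight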
            case LLM_ARCH_NEO_BERT:
@@ -7498,6 +7553,175 @@ struct llm_build_bert : public llm_graph_context {
    }
};

struct llm_build_modern_bert : public llm_graph_context {
    llm_build_modern_bert(const llama_model & model, const llm_graph_params & params)
        : llm_graph_context(params) {

        const int64_t n_embd      = hparams.n_embd;
        const int64_t n_layer     = hparams.n_layer;
        const int64_t n_head      = hparams.n_head();
        const int64_t n_head_kv   = hparams.n_head_kv();
        const int64_t n_embd_head = hparams.n_embd_head_v;
        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa(); // == n_head_kv * n_embd_head
        const int64_t n_tokens    = ubatch.n_tokens;

        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

        // RoPE params
        const int32_t rope_type  = LLAMA_ROPE_TYPE_NEOX; // ModernBERT uses rotary embeddings
        const int32_t n_rot      = hparams.n_rot;
        const int32_t n_ctx_orig = hparams.n_ctx_train;

        ggml_tensor * cur;
        ggml_tensor * inpL;
        ggml_tensor * inp_pos = nullptr;

        // ModernBERT needs positions for RoPE
        inp_pos = build_inp_pos();

        // 1) embeddings (token + optional type), no absolute position embedding
        inpL = build_inp_embd(model.tok_embd);
        if (model.type_embd) {
            ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
            inpL = ggml_add(ctx0, inpL, type_row0);
        }
        cb(inpL, "inp_embd", -1);

        // 2) embeddings LayerNorm (embeddings.norm)
        inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
        cb(inpL, "inp_norm", -1);

        auto * inp_attn = build_attn_inp_no_cache();

        ggml_tensor * inp_out_ids = build_inp_out_ids();

        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * x = inpL;

            // 3) pre-attention norm (attn_norm); layer 0 may be Identity() -> nullptr
            ggml_tensor * x_attn_in = x;
            if (model.layers[il].attn_norm) {
                x_attn_in = build_norm(x,
                        model.layers[il].attn_norm,
                        model.layers[il].attn_norm_b,
                        LLM_NORM, il);
                cb(x_attn_in, "attn_pre_norm", il);
            } else {
                cb(x_attn_in, "attn_pre_norm_identity", il);
            }

            // 4) attention: fused Wqkv -> split -> heads -> RoPE(Q,K) -> attn -> Wo
            ggml_tensor * qkv = nullptr;
            ggml_tensor * Qcur;
            ggml_tensor * Kcur;
            ggml_tensor * Vcur;

            GGML_ASSERT(model.layers[il].wqkv); // ModernBERT uses fused QKV

            qkv = build_lora_mm(model.layers[il].wqkv, x_attn_in);
            cb(qkv, "wqkv", il);

            if (model.layers[il].bqkv) {
                qkv = ggml_add(ctx0, qkv, model.layers[il].bqkv);
                cb(qkv, "bqkv", il);
            }

            // fused layout: [ (n_embd + 2*n_embd_gqa), n_tokens ]
            Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd,     n_tokens, qkv->nb[1], 0*sizeof(float)*(n_embd)));
            Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], 1*sizeof(float)*(n_embd)));
            Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_gqa, n_tokens, qkv->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
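
            // editor's note (illustrative, not part of the upstream diff): for a
            // hypothetical config with n_embd = 768 and n_head_kv == n_head (so
            // n_embd_gqa == n_embd == 768), each fused row is 768 + 2*768 = 2304
            // floats wide:
            //   floats [   0,  768) -> Q
            //   floats [ 768, 1536) -> K
            //   floats [1536, 2304) -> V
            // the byte offsets passed to ggml_view_2d above are exactly these
            // boundaries scaled by sizeof(float)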

            // optional per-head Q/K norm
            if (model.layers[il].attn_q_norm) {
                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, LLM_NORM, il);
            }
            if (model.layers[il].attn_k_norm) {
                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, LLM_NORM, il);
            }

            // split into heads
            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

            // RoPE (NEOX) on Q and K
            Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr,
                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                    ext_factor, attn_factor, beta_fast, beta_slow);

            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                    ext_factor, attn_factor, beta_fast, beta_slow);

            cb(Qcur, "Qcur_rope", il);
            cb(Kcur, "Kcur_rope", il);
            cb(Vcur, "Vcur", il);
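
            // editor's note (assumption about the rope convention, not from the
            // diff): LLAMA_ROPE_TYPE_NEOX pairs dimension i with i + n_rot/2
            // inside each head, matching HF ModernBERT's rotate_half layout,
            // rather than the interleaved (2i, 2i+1) pairing of the original
            // LLaMA rope; freq_base/freq_scale come from the loaded hparams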

            ggml_tensor * attn_out = build_attn(
                    inp_attn,
                    model.layers[il].wo, model.layers[il].bo, // Wo, optional bias
                    Qcur, Kcur, Vcur,
                    /*kq_b*/  nullptr,
                    /*v_mla*/ nullptr,
                    1.0f / sqrtf(float(n_embd_head)),
                    il);
            cb(attn_out, "attn_out", il);

            // residual after attention
            ggml_tensor * cur_attn = ggml_add(ctx0, attn_out, x);

            // if we subselect outputs, do it at the last layer, after the attention residual
            if (il == n_layer - 1 && inp_out_ids) {
                cur_attn = ggml_get_rows(ctx0, cur_attn, inp_out_ids);
                x        = ggml_get_rows(ctx0, x,        inp_out_ids);
            }

            // 5) pre-MLP norm (mlp_norm)
            ggml_tensor * h = build_norm(cur_attn,
                    model.layers[il].ffn_norm,
                    model.layers[il].ffn_norm_b,
                    LLM_NORM, il);
            cb(h, "mlp_pre_norm", il);

            // 6) MLP (prefer GEGLU if a gate tensor exists or ffn_up has 2*n_ff rows)
            ggml_tensor * mlp_out = nullptr;

            const bool has_gate_tensor = (model.layers[il].ffn_gate != nullptr);
            const bool up_is_2x        = (model.layers[il].ffn_up && model.layers[il].ffn_up->ne[0] == 2*hparams.n_ff());

            if (has_gate_tensor || up_is_2x) {
                mlp_out = build_ffn(
                        h,
                        model.layers[il].ffn_up,   /*up_b*/   nullptr, /*up_shexp*/   nullptr,
                        model.layers[il].ffn_gate, /*gate_b*/ nullptr, /*gate_shexp*/ nullptr,
                        model.layers[il].ffn_down, /*down_b*/ nullptr, /*down_shexp*/ nullptr,
                        /*act_scales*/ nullptr,
                        LLM_FFN_GEGLU, LLM_FFN_PAR, il);
                cb(mlp_out, "ffn_out_geglu", il);
            } else {
                mlp_out = build_ffn(
                        h,
                        model.layers[il].ffn_up,   /*up_b*/ nullptr, /*up_shexp*/ nullptr,
                        /*gate*/ nullptr, /*gate_b*/ nullptr, /*gate_shexp*/ nullptr,
                        model.layers[il].ffn_down, /*down_b*/ nullptr, /*down_shexp*/ nullptr,
                        /*act_scales*/ nullptr,
                        LLM_FFN_GELU, LLM_FFN_SEQ, il);
                cb(mlp_out, "ffn_out_gelu", il);
            }
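
            // editor's sketch (assumes ggml's fused-GEGLU convention): with
            // LLM_FFN_GEGLU, build_ffn treats ffn_up as a fused [gate|up]
            // projection, splits its output in half along the feature dim, and
            // computes roughly
            //   ggml_tensor * gu   = build_lora_mm(ffn_up, h);  // 2*n_ff rows
            //   ggml_tensor * gate = ggml_view_2d(ctx0, gu, n_ff, n_tokens, gu->nb[1], 0);
            //   ggml_tensor * up   = ggml_view_2d(ctx0, gu, n_ff, n_tokens, gu->nb[1], n_ff*sizeof(float));
            //   cur = ggml_mul(ctx0, ggml_gelu(ctx0, gate), up);
            // before applying ffn_down, i.e. ModernBERT's Wi -> chunk(2) ->
            // act(x) * gate -> Wo MLP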

            // 7) residual after MLP
            ggml_tensor * cur_layer = ggml_add(ctx0, mlp_out, cur_attn);

            // 8) feed into next layer
            inpL = cur_layer;
        }

        // 9) final model norm (final_norm)
        cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1);
        cb(cur, "final_norm", -1);

        res->t_embd = cur;
        ggml_build_forward_expand(gf, cur);
    }
};
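
Once this builder is dispatched from build_graph (see the case added below), the standard embedding path should exercise it end to end; for example, assuming a converted ModernBERT GGUF (hypothetical filename):

    llama-embedding -m granite-embedding-small.gguf -p "hello world"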
struct llm_build_neo_bert : public llm_graph_context {
    llm_build_neo_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
        const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -18186,6 +18410,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            {
                llm = std::make_unique<llm_build_bert>(*this, params);
            } break;
        case LLM_ARCH_MODERN_BERT:
            {
                llm = std::make_unique<llm_build_modern_bert>(*this, params);
            } break;
        case LLM_ARCH_NEO_BERT:
            {
                llm = std::make_unique<llm_build_neo_bert>(*this, params);
@@ -18666,6 +18894,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
        case LLM_ARCH_GROK:
        case LLM_ARCH_DBRX:
        case LLM_ARCH_BERT:
        case LLM_ARCH_MODERN_BERT:
        case LLM_ARCH_NOMIC_BERT:
        case LLM_ARCH_NOMIC_BERT_MOE:
        case LLM_ARCH_STABLELM:

View File

@@ -23,6 +23,7 @@ enum llm_type {
    LLM_TYPE_17M,
    LLM_TYPE_22M,
    LLM_TYPE_33M,
    LLM_TYPE_47M,
    LLM_TYPE_60M,
    LLM_TYPE_70M,
    LLM_TYPE_80M,

View File

@@ -1661,10 +1661,13 @@ private:
void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
    struct gguf_context * ctx = ml.meta.get();

    LLAMA_LOG_INFO("Determining Vocab Type\n");

    // determine vocab type
    {
        ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model);
        ml.get_key(LLM_KV_TOKENIZER_PRE,   tokenizer_pre, false);

        LLAMA_LOG_INFO("pre tokenizer model: %s\n", tokenizer_pre.c_str());
        LLAMA_LOG_INFO("tokenizer model: %s\n", tokenizer_model.c_str());

        ml.get_key(LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, n_token_types, false);
@@ -1813,7 +1816,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
            LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
            LLAMA_LOG_WARN("%s: \n", __func__);
            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
        } else if (tokenizer_pre == "default" || tokenizer_pre == "modern-bert") {
            pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
        } else if (
                tokenizer_pre == "llama3" ||

View File

@@ -126,6 +126,7 @@ static int llama_model_load(const std::string & fname, std::vector<std::string>
        if (!model.load_tensors(ml)) {
            return -2;
        }

    } catch (const std::exception & err) {
        LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
        return -1;