some cleanup
This commit is contained in:
parent ac67fc6887
commit cc40378d27
@@ -1377,7 +1377,7 @@ ggml_tensor * llm_graph_context::build_attn(
     // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams
     LLAMA_LOG_INFO("ubatch.equal_seqs() = %d, n_seqs = %d\n", ubatch.equal_seqs(), ubatch.n_seqs);
 
-    //assert(!ubatch.equal_seqs());
+    assert(!ubatch.equal_seqs());
 
     ggml_tensor * q = q_cur;
     ggml_tensor * k = k_cur;
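Note: the log call and the re-enabled assert above are debug instrumentation around the non-equal-sequences path. A minimal sketch of how such instrumentation is commonly kept but compiled out of release builds, assuming standard <cassert>/NDEBUG semantics (GGML_ASSERT being the ggml-side macro that also aborts in release builds):

    // sketch only: debug-build instrumentation for the ubatch.equal_seqs() case
    #ifndef NDEBUG
        LLAMA_LOG_INFO("ubatch.equal_seqs() = %d, n_seqs = %d\n", ubatch.equal_seqs(), ubatch.n_seqs);
    #endif
    assert(!ubatch.equal_seqs());          // compiled out when NDEBUG is defined
    //GGML_ASSERT(!ubatch.equal_seqs());   // alternative that also fires in release builds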
@@ -7578,7 +7578,7 @@ struct llm_build_modern_bert : public llm_graph_context {
         // ModernBERT needs positions for RoPE
         inp_pos = build_inp_pos();
 
-        // 1) embeddings (token + optional type), NO absolute pos embed
+        // embeddings (token + optional type), NO absolute pos embed
         inpL = build_inp_embd(model.tok_embd);
 
         if (model.type_embd) {
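The body of the if (model.type_embd) branch is elided by this hunk. A minimal sketch of what the token-type addition typically looks like for BERT-style graphs in ggml, assuming token types are fixed to type 0 as in the existing BERT builder (illustrative, not necessarily what this builder does):

    // sketch: broadcast-add the embedding row for token type 0 ("sentence A") to every token
    if (model.type_embd) {
        ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
        inpL = ggml_add(ctx0, inpL, type_row0); // [n_embd] broadcast over [n_embd, n_tokens]
    }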
@@ -7587,7 +7587,7 @@ struct llm_build_modern_bert : public llm_graph_context {
         }
         cb(inpL, "inp_embd", -1);
 
-        // 2) embeddings LayerNorm (embeddings.norm)
+        // embeddings LayerNorm (embeddings.norm)
         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
         cb(inpL, "inp_norm", -1);
 
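For reference, LLM_NORM selects a standard LayerNorm here (scale model.tok_norm, bias model.tok_norm_b), as opposed to LLM_NORM_RMS. Per token, over the embedding dimension:

    y = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta

where \mu and \sigma^2 are the mean and variance of x, \gamma/\beta are the norm weight and bias tensors, and \epsilon is the usual small stability constant.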
@@ -7673,14 +7673,14 @@ struct llm_build_modern_bert : public llm_graph_context {
                 x = ggml_get_rows(ctx0, x, inp_out_ids);
             }
 
-            // 5) pre-MLP norm (mlp_norm)
+            // pre-MLP norm (mlp_norm)
             ggml_tensor * h = build_norm(cur_attn,
                     model.layers[il].ffn_norm,
                     model.layers[il].ffn_norm_b,
                     LLM_NORM, il);
             cb(h, "mlp_pre_norm", il);
 
-            // 6) MLP (prefer GEGLU if gate exists or up has 2*n_ff rows)
+            // MLP (prefer GEGLU if gate exists or up has 2*n_ff rows)
             ggml_tensor * mlp_out = nullptr;
             const bool has_gate_tensor = (model.layers[il].ffn_gate != nullptr);
             const bool up_is_2x = (model.layers[il].ffn_up && model.layers[il].ffn_up->ne[0] == 2*hparams.n_ff());
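The up_is_2x case covers checkpoints that store the gate and up projections fused into a single ffn_up tensor with 2*n_ff rows. A minimal sketch of that GEGLU path in raw ggml, assuming the fused layout and the usual "activate one half, multiply by the other" convention (which half gets the activation varies by checkpoint; names are illustrative, not necessarily what this builder emits):

    // sketch: fused up/gate projection -> split in half -> GELU(x0) * x1 -> down projection
    ggml_tensor * up    = ggml_mul_mat(ctx0, model.layers[il].ffn_up, h);      // [2*n_ff, n_tokens]
    const int64_t split = up->ne[0] / 2;                                       // == n_ff
    ggml_tensor * x0    = ggml_cont(ctx0, ggml_view_2d(ctx0, up, split, up->ne[1], up->nb[1], 0));
    ggml_tensor * x1    = ggml_cont(ctx0, ggml_view_2d(ctx0, up, split, up->ne[1], up->nb[1], split*ggml_element_size(up)));
    ggml_tensor * act   = ggml_mul(ctx0, ggml_gelu(ctx0, x0), x1);             // GEGLU
    mlp_out             = ggml_mul_mat(ctx0, model.layers[il].ffn_down, act);  // [n_embd, n_tokens]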
@@ -7705,14 +7705,14 @@ struct llm_build_modern_bert : public llm_graph_context {
                 cb(mlp_out, "ffn_out_gelu", il);
             }
 
-            // 7) Residual after MLP
+            // Residual after MLP
             ggml_tensor * cur_layer = ggml_add(ctx0, mlp_out, cur_attn);
 
-            // 8) feed into next layer
+            // feed into next layer
             inpL = cur_layer;
         }
 
-        // 9) final model norm (final_norm)
+        // final model norm (final_norm)
         cur = build_norm(inpL, model.output_norm, model.output_norm_b, LLM_NORM, -1);
         cb(cur, "final_norm", -1);
 
@@ -1816,7 +1816,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
                 LLAMA_LOG_WARN("%s: \n", __func__);
                 pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
-            } else if (tokenizer_pre == "default" || tokenizer_pre == "modern-bert") {
+            } else if (tokenizer_pre == "default" || tokenizer_pre == "modern-bert") /* need to fix modern-bert pre tokenizer */ {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
             } else if (
                     tokenizer_pre == "llama3" ||
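The added comment flags that "modern-bert" currently falls through to the default pre-tokenizer. A minimal sketch of what a dedicated branch could look like, assuming a hypothetical LLAMA_VOCAB_PRE_TYPE_MODERN_BERT enum value were added and a regex split matching the model's tokenizer.json were wired up (illustrative only):

            } else if (tokenizer_pre == "modern-bert") {
                // hypothetical: LLAMA_VOCAB_PRE_TYPE_MODERN_BERT does not exist yet;
                // it would need its own entry in the pre-type enum plus a matching
                // pre-tokenizer regex set in the tokenizer setup
                pre_type = LLAMA_VOCAB_PRE_TYPE_MODERN_BERT;
            }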