mtp-batch (wip): merge mtp and model graph
parent 1318b2de82
commit 042eb8a829
@@ -729,7 +729,8 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
-llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret,
+                                                 bool do_mtp_kv_update) {
     if (mctx && !mctx->apply()) {
         LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
@@ -741,7 +742,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
 
     // the new graph parameters
     // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters
-    const auto gparams = graph_params(res, ubatch, mctx, gtype);
+    const auto gparams = graph_params(res, ubatch, mctx, gtype, do_mtp_kv_update);
 
     if (!graph_reuse_disable && res->can_reuse(gparams)) {
         //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
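The flag is carried inside the graph parameters, so the graph-reuse check sees it like any other topology-determining field. A minimal sketch of the intended call shape (condensed from the hunks in this file; variable names as used there):

    // caller side: one entry point for both the regular decode graph and the MTP KV update
    ggml_status status;
    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status,
                                      /*do_mtp_kv_update=*/true);

    // inside process_ubatch(): the flag travels with the graph parameters, so a cached
    // graph is only considered for reuse against parameters built for the same mode
    const auto gparams = graph_params(res, ubatch, mctx, gtype, do_mtp_kv_update);
    if (!graph_reuse_disable && res->can_reuse(gparams)) {
        // reuse the previous graph
    }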
@@ -781,7 +782,15 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
         //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
     }
 
+    const int64_t t_exec_start_us = ggml_time_us();
     const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
+    const int64_t t_exec_end_us = ggml_time_us();
+    LLAMA_LOG_INFO(
+        "[PERF] Graph compute time: %.2f ms (ubatch_size: %u, MTP path: %s)\n",
+        (t_exec_end_us - t_exec_start_us) / 1000.0,
+        ubatch.n_tokens,
+        do_mtp_kv_update ? "yes" : "no"
+    );
     if (status != GGML_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
         ret = status;
@@ -850,7 +859,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     cparams.causal_attn = false;
 
     ggml_status status;
-    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status);
+    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false);
 
     cparams.causal_attn = causal_attn_org;
 
@@ -1092,7 +1101,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }
 
         ggml_status status;
-        const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
+        const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1130,39 +1139,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
         //}
 
-        if (do_mtp_kv_update) {
-            LLAMA_LOG_INFO(
-                "[MTP BATCHING] Processing MTP KV update for a ubatch of %u tokens.\n",
-                ubatch.n_tokens
-            );
-            auto res_mtp = std::make_unique<llm_graph_result>(graph_max_nodes());
-
-            auto params_mtp = mtp_graph_params(res_mtp.get(), ubatch, mctx.get());
-            ggml_backend_sched_t sched_mtp = params_mtp.sched;
-
-            auto * gf_mtp = model.build_mtp_graph(params_mtp);
-            if (gf_mtp) {
-                ggml_backend_sched_alloc_graph(sched_mtp, gf_mtp);
-
-                ggml_tensor* prev_embedding_tensor = res->get_embd();
-                ggml_tensor* embd_input_mtp = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embeddings_batch_input");
-
-                // ggml_backend_tensor_set(embd_input_mtp, prev_embedding_tensor->data, 0, ggml_nbytes(prev_embedding_tensor));
-                ggml_backend_tensor_copy(prev_embedding_tensor, embd_input_mtp);
-
-                ggml_backend_sched_graph_compute(sched_mtp, gf_mtp);
-
-                if (ubatch.output[0]) {
-                    struct ggml_tensor * logits_mtp = res_mtp->get_logits();
-                    if (logits_mtp) {
-                        float * logits_dest = logits + n_outputs_prev * n_vocab;
-                        ggml_backend_tensor_get(logits_mtp, logits_dest, 0, ggml_nbytes(logits_mtp));
-                    }
-                }
-            }
-            ggml_backend_sched_free(sched_mtp);
-        }
-
         auto * t_logits = res->get_logits();
         auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
         embd_tensor = res->get_embd();
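With the hand-rolled MTP graph gone, decode() no longer allocates a second scheduler or copies embeddings and logits around per ubatch; the MTP tail runs inside the main graph and its outputs come back through the same llm_graph_result. A rough sketch of the resulting loop body (error handling and output bookkeeping elided; names as in the hunks above):

    ggml_status status;
    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status,
                                      do_mtp_kv_update);
    if (!res) {
        // the last ubatch failed or was aborted -> roll back its positions in the KV cache
    }

    // logits and embeddings of the merged graph, MTP path included
    auto * t_logits = res->get_logits();
    auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
    embd_tensor     = res->get_embd();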
@@ -1442,7 +1418,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 
     auto * res = gf_res_reserve.get();
 
-    const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);
+    const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, false);
 
     res->reset();
 
@@ -1462,8 +1438,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 llm_graph_params llama_context::graph_params(
         llm_graph_result * res,
         const llama_ubatch & ubatch,
-        const llama_memory_context_i * mctx,
-        llm_graph_type gtype) const {
+        const llama_memory_context_i * mctx,
+        llm_graph_type gtype,
+        bool update_mtp_kv) const {
     return {
         /*.arch =*/ model.arch,
         /*.hparams =*/ model.hparams,
@@ -1476,36 +1453,13 @@ llm_graph_params llama_context::graph_params(
         /*.loras =*/ &loras,
         /*.mctx =*/ mctx,
         /*.cross =*/ &cross,
+        /*.update_mtp_kv =*/ update_mtp_kv,
         /*.n_outputs =*/ n_outputs,
         /*.cb =*/ graph_get_cb(),
         /*.res =*/ res,
     };
 }
 
-llm_graph_params llama_context::mtp_graph_params(
-        llm_graph_result * res,
-        const llama_ubatch& ubatch,
-        const llama_memory_context_i * mctx) {
-    size_t n_nodes = std::max<uint32_t>(1024u, 8u * 8u * (((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / model.hparams.n_layer));
-    ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes);
-    return {
-        /*.arch =*/ model.arch,
-        /*.hparams =*/ model.hparams,
-        /*.cparams =*/ cparams,
-        /*.ubatch =*/ ubatch,
-        /*.gtype =*/ LLM_GRAPH_TYPE_DECODER,
-        /*.sched =*/ temp_sched,
-        /*.backend_cpu =*/ backend_cpu,
-        /*.cvec =*/ &cvec,
-        /*.loras =*/ &loras,
-        /*.mctx =*/ mctx,
-        /*.cross =*/ &cross,
-        /*.n_outputs =*/ 1,
-        /*.cb =*/ graph_get_cb(temp_sched),
-        /*.res =*/ res,
-    };
-}
-
 std::unique_ptr<llama_memory_context_i> llama_context::mtp_memory_batch(const llama_batch& batch_inp) {
     const auto& vocab = model.vocab;
     const auto& hparams = model.hparams;
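Dropping mtp_graph_params() also drops the per-update node-count estimate and temporary scheduler; a caller that still needs MTP-specific parameters would go through the regular builder instead. A hypothetical replacement call, shown only to illustrate the substitution (not a line from this patch):

    // before (removed): auto params_mtp = mtp_graph_params(res_mtp.get(), ubatch, mctx.get());
    // after: the normal parameter builder carries the flag and uses the shared scheduler
    const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DECODER,
                                      /*update_mtp_kv=*/true);
    // gparams.update_mtp_kv is what llama_model::build_graph() later inspects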
@@ -2240,7 +2194,7 @@ void llama_context::opt_epoch_iter(
 
         auto * res = gf_res_prev.get();
 
-        const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT);
+        const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, false);
 
         res->reset();
 
@@ -99,7 +99,8 @@ struct llama_context {
             const llama_ubatch & ubatch,
             llm_graph_type gtype,
             llama_memory_context_i * mctx,
-            ggml_status & ret);
+            ggml_status & ret,
+            const bool do_mtp_kv_update);
 
     int encode(const llama_batch & batch_inp);
     int decode(const llama_batch & batch_inp);
@@ -200,8 +201,6 @@ public:
     // reserve a graph with a dummy ubatch of the specified size
     ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
 
-    llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx);
-
     void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i);
 
     ggml_backend_sched_t create_temp_scheduler(size_t n_nodes);
@@ -213,7 +212,8 @@ private:
             llm_graph_result * res,
             const llama_ubatch & ubatch,
             const llama_memory_context_i * mctx,
-            llm_graph_type gtype) const;
+            llm_graph_type gtype,
+            bool update_mtp_kv) const;
 
     llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const;
 
@@ -402,6 +402,7 @@ struct llm_graph_params {
     const llama_adapter_loras * loras;
     const llama_memory_context_i * mctx;
     const llama_cross * cross;
+    bool update_mtp_kv;
 
     uint32_t n_outputs;
 
@@ -13787,7 +13787,8 @@ struct llm_build_glm4 : public llm_graph_context {
 };
 
 struct llm_build_glm4_moe : public llm_graph_context {
-    llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params, bool build_mtp_path)
+        : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -13932,68 +13933,57 @@ struct llm_build_glm4_moe : public llm_graph_context {
         cur = inpL;
         cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
 
-        cb(cur, "result_norm", -1);
+        // cb(cur, "result_norm", -1);
         res->t_embd = cur;
 
-        // lm_head
-        cur = build_lora_mm(model.output, cur);
+        if (build_mtp_path) {
+            const int il_mtp = hparams.n_layer - 1;
+            const auto & mtp_layer = model.layers[il_mtp];
+
+            ggml_tensor * mtp_logits = build_mtp_tail(mtp_layer, cur, n_embd_head);
+            res->t_logits = mtp_logits;
+        } else {
+            // lm_head
+            cur = build_lora_mm(model.output, cur);
+            res->t_logits = cur;
+        }
 
         cb(cur, "result_output", -1);
-        res->t_logits = cur;
 
-        ggml_build_forward_expand(gf, cur);
+        ggml_build_forward_expand(gf, res->t_logits);
     }
-};
 
-struct llm_build_glm4_moe_mtp : public llm_graph_context {
-    llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-
-        const int64_t n_embd_head = hparams.n_embd_head_v;
-        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-
+private:
+    ggml_tensor * build_mtp_tail(const llama_layer & mtp_layer, ggml_tensor * prev_embeddings,
+            int64_t n_embd_head
+    ) {
         const int il = hparams.n_layer - 1;
-        const auto & mtp_layer = model.layers[il];
 
         ggml_tensor * inp_pos = build_inp_pos();
         auto * inp_attn = build_attn_inp_kv_unified();
 
-        ggml_tensor* prev_embeddings_batch = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_embd, n_tokens);
-        ggml_set_name(prev_embeddings_batch, "mtp_prev_embeddings_batch_input");
-        ggml_set_input(prev_embeddings_batch);
-
         ggml_tensor * token_emb = build_inp_embd_mtp(mtp_layer.nextn.embed_tokens);
 
         ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
-        ggml_tensor * hidden_state_norm = build_norm(prev_embeddings_batch, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
+
+        ggml_tensor * hidden_state_norm = build_norm(prev_embeddings, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
 
         ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0);
 
         ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined);
 
         // now proceed through last layer (skipped in main model)
         ggml_tensor * inpSA = cur;
 
-        // Pre-attention norm for the MTP block
-        ggml_tensor* attn_inp = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il);
+        cur = build_norm(cur, mtp_layer.attn_norm, NULL, LLM_NORM_RMS, il);
 
         // self-attention
         {
             ggml_tensor * Qcur = build_lora_mm(mtp_layer.wq, cur);
-            if (mtp_layer.bq) {
-                Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq);
-            }
+            if (mtp_layer.bq) Qcur = ggml_add(ctx0, Qcur, mtp_layer.bq);
             cb(Qcur, "Qcur", il);
 
             ggml_tensor * Kcur = build_lora_mm(mtp_layer.wk, cur);
-            if (mtp_layer.bk) {
-                Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk);
-            }
+            if (mtp_layer.bk) Kcur = ggml_add(ctx0, Kcur, mtp_layer.bk);
             cb(Kcur, "Kcur", il);
 
             ggml_tensor * Vcur = build_lora_mm(mtp_layer.wv, cur);
-            if (mtp_layer.bv) {
-                Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv);
-            }
+            if (mtp_layer.bv) Vcur = ggml_add(ctx0, Vcur, mtp_layer.bv);
             cb(Vcur, "Vcur", il);
 
             Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
@@ -14025,10 +14015,10 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
             cb(Qcur, "Qcur", il);
             cb(Kcur, "Kcur", il);
             cb(Vcur, "Vcur", il);
 
             cur = build_attn(inp_attn,
-                    mtp_layer.wo, NULL,
-                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+                    mtp_layer.wo, NULL,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
         }
 
         ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
@@ -14068,9 +14058,7 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
         cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
         cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);
 
-        res->t_logits = cur;
-
-        ggml_build_forward_expand(gf, res->t_logits);
+        return cur;
     }
 };
 
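The behavioural difference in the builder: the old standalone graph read the previous hidden states through a dedicated input tensor that decode() had to fill by hand, while the merged tail takes them directly as a graph edge from the final norm. A condensed sketch of the two states (not verbatim code):

    // before (removed): separate graph with an explicit host-side copy
    //   ggml_tensor * embd_in = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embeddings_batch_input");
    //   ggml_backend_tensor_copy(res->get_embd(), embd_in);

    // after: the MTP tail is part of the same graph, fed by `cur` (the post-norm embeddings)
    if (build_mtp_path) {
        res->t_logits = build_mtp_tail(model.layers[hparams.n_layer - 1], cur, n_embd_head);
    } else {
        res->t_logits = build_lora_mm(model.output, cur);   // regular lm_head
    }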
@@ -18299,8 +18287,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
 }
 
 ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
+    const int64_t t_start_us = ggml_time_us();
+
     std::unique_ptr<llm_graph_context> llm;
 
+    const bool build_mtp = params.update_mtp_kv;
+
     switch (arch) {
         case LLM_ARCH_LLAMA:
             {
@@ -18519,7 +18511,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_GLM4_MOE:
             {
-                llm = std::make_unique<llm_build_glm4_moe>(*this, params);
+                llm = std::make_unique<llm_build_glm4_moe>(*this, params, build_mtp);
             } break;
         case LLM_ARCH_BITNET:
             {
@@ -18660,22 +18652,12 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
 
     // add on pooling layer
     llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
 
-    return llm->res->get_gf();
-}
-
-ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params) const {
-    std::unique_ptr<llm_graph_context> llm;
-
-    switch (arch) {
-        case LLM_ARCH_GLM4_MOE:
-            {
-                llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params);
-            } break;
-        default:
-            GGML_ABORT("fatal error");
-    }
-
+    const int64_t t_end_us = ggml_time_us(); // end of timer
+    LLAMA_LOG_INFO(
+        "[PERF] Graph build time: %.2f ms (MTP path: %s)\n",
+        (t_end_us - t_start_us) / 1000.0,
+        build_mtp ? "yes" : "no"
+    );
     return llm->res->get_gf();
 }
 
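At the dispatch level build_mtp_graph() disappears and the single builder decides per call, timing the build as a side effect. Sketch of the GLM4_MOE case after this change (condensed from the hunks above):

    ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
        const int64_t t_start_us = ggml_time_us();
        const bool    build_mtp  = params.update_mtp_kv;

        std::unique_ptr<llm_graph_context> llm;
        switch (arch) {
            // ...
            case LLM_ARCH_GLM4_MOE:
                {
                    llm = std::make_unique<llm_build_glm4_moe>(*this, params, build_mtp);
                } break;
            // ...
        }

        // ... pooling, etc. ...
        LLAMA_LOG_INFO("[PERF] Graph build time: %.2f ms (MTP path: %s)\n",
                (ggml_time_us() - t_start_us) / 1000.0, build_mtp ? "yes" : "no");
        return llm->res->get_gf();
    }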
@@ -475,7 +475,6 @@ struct llama_model {
 
     // TODO: move this to new llm_arch_model_i interface
     ggml_cgraph * build_graph(const llm_graph_params & params) const;
-    ggml_cgraph * build_mtp_graph(const llm_graph_params& params) const;
 
 private:
     struct impl;
@@ -1739,7 +1739,7 @@ struct server_queue {
 
         while (true) {
             QUE_DBG("%s", "processing new tasks\n");
-
+            const int64_t t_turn_start_us = ggml_time_us();
             while (true) {
                 std::unique_lock<std::mutex> lock(mutex_tasks);
                 if (!running) {
@@ -1762,7 +1762,11 @@ struct server_queue {
             QUE_DBG("%s", "update slots\n");
 
             callback_update_slots();
-
+            const int64_t t_turn_end_us = ggml_time_us();
+            SRV_DBG(
+                "[PERF] Server turn time: %.2f ms\n",
+                (t_turn_end_us - t_turn_start_us) / 1000.0
+            );
             QUE_DBG("%s", "waiting for new tasks\n");
             {
                 std::unique_lock<std::mutex> lock(mutex_tasks);