mtp-batch (wip): merge glm graphs

parent 042eb8a829
commit df64508b93
@@ -373,10 +373,24 @@ llama_token mtp_speculative_gen_draft(
    if (!smpl) {
        return -1;
    }

    const float * draft_input_hidden_state = llama_get_embeddings_ith(ctx, last_tok_idx);
    llama_set_draft_input_hidden_state(ctx, draft_input_hidden_state);

    llama_batch mtp_batch = llama_batch_init(1, 0, 1);
    common_batch_add(mtp_batch, id_last, n_past, {0}, true);
    mtp_batch.update_mtp_kv = true;

    LOG_INF(
        "[DEBUG-DRAFT-IN] Generating draft. id_last=%d, n_past=%d, last_tok_idx=%d\n",
        id_last, n_past, last_tok_idx
    );

    mtp_batch.update_mtp_kv = false;
    mtp_batch.use_mtp_head = true;

    LOG_INF("[DEBUG-DRAFT-CALL] Calling llama_decode for draft. update_mtp_kv=%s, use_mtp_head=%s\n",
        mtp_batch.update_mtp_kv ? "true" : "false",
        mtp_batch.use_mtp_head ? "true" : "false"
    );

    llama_decode(ctx, mtp_batch);
    llama_batch_free(mtp_batch);
@@ -419,6 +433,7 @@ void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_d
    }

    mtp_batch.update_mtp_kv = true;
    mtp_batch.use_mtp_head = true;

    llama_decode(ctx, mtp_batch);
@@ -230,7 +230,8 @@ extern "C" {
        int32_t * n_seq_id;
        llama_seq_id ** seq_id;
        int8_t * logits; // TODO: rename this to "output"
        bool update_mtp_kv;
        bool use_mtp_head;
    } llama_batch;

    enum llama_model_kv_override_type {
@@ -1455,8 +1456,7 @@ extern "C" {
            ggml_opt_epoch_callback callback_train,
            ggml_opt_epoch_callback callback_eval);

    // LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
    //     const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
    LLAMA_API void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state);

#ifdef __cplusplus
}
@@ -841,6 +841,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
        /*n_seq_id =*/ nullptr,
        /*seq_id =*/ nullptr,
        /*logits =*/ nullptr,
        /*.use_mtp_head =*/ false,
        /*update_mtp_kv =*/ false,
    };
@@ -730,7 +730,7 @@ bool llama_context::apply_adapter_cvec(
}

llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret,
        bool do_mtp_kv_update) {
        bool do_mtp_kv_update, bool use_mtp_head) {
    if (mctx && !mctx->apply()) {
        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
        ret = GGML_STATUS_FAILED;
@@ -742,7 +742,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll

    // the new graph parameters
    // in order to correctly reuse a graph, its full topology has to be uniquely determined by these parameters
    const auto gparams = graph_params(res, ubatch, mctx, gtype, do_mtp_kv_update);
    const auto gparams = graph_params(res, ubatch, mctx, gtype, do_mtp_kv_update, use_mtp_head);

    if (!graph_reuse_disable && res->can_reuse(gparams)) {
        //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__);
@@ -773,6 +773,29 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
        }
    }

    if (do_mtp_kv_update || (use_mtp_head && !do_mtp_kv_update)) { // If it is any MTP operation
        const char * target_tensor_name = "result_embd_pooled";
        ggml_tensor* hidden_states_input = ggml_get_tensor(res->get_ctx(), target_tensor_name);

        const float * source_hidden_state = nullptr;
        if (do_mtp_kv_update) {
            // Cache warming uses the entire embeddings buffer
            source_hidden_state = this->embd;
        } else {
            // Draft generation uses the specific state
            source_hidden_state = this->draft_input_hidden_state;
        }

        if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
            ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input));
        } else {
            LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n",
                __func__, target_tensor_name);
            ret = GGML_STATUS_FAILED;
            return nullptr;
        }
    }

    // set the input data for the input tensors
    {
        //const auto t_start_us = ggml_time_us();
@@ -798,7 +821,12 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
    }

    ret = GGML_STATUS_SUCCESS;

    if (do_mtp_kv_update || use_mtp_head) {
        ggml_tensor * sum_tensor = ggml_get_tensor(res->get_ctx(), "mtp_input_sum");
        if (sum_tensor) {
            LLAMA_LOG_WARN("[DEBUG-SUM] MTP input sum node successfully created.\n");
        }
    }
    return res;
}
@@ -859,7 +887,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
    cparams.causal_attn = false;

    ggml_status status;
    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false);
    const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status, false, false);

    cparams.causal_attn = causal_attn_org;
@@ -972,6 +1000,10 @@ int llama_context::encode(const llama_batch & batch_inp) {

int llama_context::decode(const llama_batch & batch_inp) {
    GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
    LLAMA_LOG_WARN("[DEBUG-DECODE-ENTRY] Entering llama_decode. update_mtp_kv=%s, use_mtp_head=%s\n",
        batch_inp.update_mtp_kv ? "true" : "false",
        batch_inp.use_mtp_head ? "true" : "false"
    );

    if (!memory) {
        LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
@@ -1080,9 +1112,24 @@ int llama_context::decode(const llama_batch & batch_inp) {

    int64_t n_outputs_prev = 0;
    const bool do_mtp_kv_update = batch_inp.update_mtp_kv;

    const bool use_mtp_head = batch_inp.use_mtp_head;
    const bool is_prompt_warmup = batch_inp.n_tokens > 1 && (this->model.hparams.nextn_predict_layers > 0);

    do {
        const auto & ubatch = mctx->get_ubatch();
        if (ubatch.n_tokens > 0) {
            std::string pos_str;
            for (uint32_t i = 0; i < std::min((uint32_t)5, ubatch.n_tokens); ++i) {
                pos_str += std::to_string(ubatch.pos[i]) + " ";
            }
            LLAMA_LOG_WARN(
                "[DEBUG-POS] ubatch_size=%u, update_mtp_kv=%s, use_mtp_head=%s. Positions: %s...\n",
                ubatch.n_tokens,
                batch_inp.update_mtp_kv ? "true" : "false",
                batch_inp.use_mtp_head ? "true" : "false",
                pos_str.c_str()
            );
        }

        // count the outputs in this ubatch
        {
@@ -1101,7 +1148,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
        }

        ggml_status status;
        const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update);
        const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status, do_mtp_kv_update, use_mtp_head);

        if (!res) {
            // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1139,6 +1186,17 @@ int llama_context::decode(const llama_batch & batch_inp) {
        //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
        //}

        // if (is_prompt_warmup) {
        // auto res_mtp = std::make_unique<llm_graph_result>(graph_max_nodes());
        // ggml_status status_mtp;

        // process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status_mtp, do_mtp_kv_update, use_mtp_head);

        // if (status_mtp != GGML_STATUS_SUCCESS) {
        // LLAMA_LOG_WARN("%s: Failure in MTP warmup ubatch\n", __func__);
        // }
        // }

        auto * t_logits = res->get_logits();
        auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
        embd_tensor = res->get_embd();
@@ -1278,7 +1336,9 @@ int llama_context::decode(const llama_batch & batch_inp) {
        // overlap with device computation.
        ggml_backend_sched_reset(sched.get());
    }

    if (!do_mtp_kv_update && !use_mtp_head) {
        LLAMA_LOG_WARN("[DEBUG-EMBD-WRITE] Main decode completed. ctx->embd (%p) now contains the hidden state for the next draft.\n", (void*)this->embd);
    }
    return 0;
}
@@ -1418,7 +1478,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u

    auto * res = gf_res_reserve.get();

    const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, false);
    const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT, false, false);

    res->reset();
@@ -1440,7 +1500,8 @@ llm_graph_params llama_context::graph_params(
        const llama_ubatch & ubatch,
        const llama_memory_context_i * mctx,
        llm_graph_type gtype,
        bool update_mtp_kv) const {
        bool update_mtp_kv,
        bool use_mtp_head) const {
    return {
        /*.arch =*/ model.arch,
        /*.hparams =*/ model.hparams,
@@ -1454,6 +1515,7 @@ llm_graph_params llama_context::graph_params(
        /*.mctx =*/ mctx,
        /*.cross =*/ &cross,
        /*.update_mtp_kv =*/ update_mtp_kv,
        /*.use_mtp_head =*/ use_mtp_head,
        /*.n_outputs =*/ n_outputs,
        /*.cb =*/ graph_get_cb(),
        /*.res =*/ res,
@@ -2194,7 +2256,7 @@ void llama_context::opt_epoch_iter(

        auto * res = gf_res_prev.get();

        const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, false);
        const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT, false, false);

        res->reset();
@@ -2983,79 +3045,6 @@ void llama_opt_epoch(
        callback_eval);
}

// void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
//     const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {

//     const auto * model = llama_get_model(ctx);

//     auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
//     std::unique_ptr<llama_memory_context_i> mctx = ctx->mtp_memory_batch(batch_inp);

//     std::vector<uint32_t> idxs;
//     idxs.push_back(n_past);
//     llama_kv_cache_unified::slot_info sinfo = {
//         /*.s0 =*/ 0,
//         /*.s1 =*/ 0,
//         /*.strm =*/ { 0 },
//         /*.idxs =*/ { idxs },
//     };
//     llama_kv_cache_unified::slot_info_vec_t sinfos;
//     sinfos.push_back(sinfo);

//     static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_sinfos(sinfos);
//     const auto& ubatch_mtp = mctx->get_ubatch();

//     //llama_ubatch ubatch_mtp;
//     //ubatch_mtp.n_tokens = 1;
//     //ubatch_mtp.pos = &n_past;

//     auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get()));
//     ggml_backend_sched_t sched = params_mtp->sched;

//     auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);

//     //if (mctx && !mctx->set_n_kv()) {
//     //    LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
//     //}
//     static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_n_kv();

//     auto * gf = model->build_mtp_graph(*params_mtp);

//     if (!gf) {
//         LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__);
//         if (sched) ggml_backend_sched_free(sched);
//         return;
//     }

//     ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
//     ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it

//     ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");
//     ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors

//     ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input");
//     ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors

//     ggml_backend_sched_graph_compute(sched, gf); // execute the graph

//     struct ggml_tensor * logits_mtp = res_mtp->get_logits();

//     if (logits_mtp) {
//         float * logits_dest = ctx->get_logits_ith(last_tok_idx);
//         ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp);
//         if (backend_res) {
//             // ggml_backend_tensor_get is the function for GPU->CPU copies.
//             // We are copying a single 32-bit integer.
//             ggml_backend_tensor_get(logits_mtp,
//                 logits_dest, // Pointer to our C++ variable
//                 0, // Starting offset in bytes
//                 ggml_nbytes(logits_mtp)); // Number of bytes to copy
//         } else {
//             LLAMA_LOG_ERROR("%s: ERROR - Could not obtain the backend for the logits tensor.", __func__);
//         }
//     } else {
//         LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__);
//     }

//     ggml_backend_sched_free(sched);
// }
void llama_set_draft_input_hidden_state(struct llama_context * ctx, const float * hidden_state) {
    ctx->draft_input_hidden_state = hidden_state;
}
@@ -61,6 +61,8 @@ struct llama_context {
    float * get_embeddings_seq(llama_seq_id seq_id);
    ggml_tensor * get_embeddings_tensor();

    const float * draft_input_hidden_state = nullptr;

    void attach_threadpool(
            ggml_threadpool_t threadpool,
            ggml_threadpool_t threadpool_batch);
@@ -100,7 +102,8 @@ struct llama_context {
            llm_graph_type gtype,
            llama_memory_context_i * mctx,
            ggml_status & ret,
            const bool do_mtp_kv_update);
            const bool do_mtp_kv_update,
            const bool use_mtp_head);

    int encode(const llama_batch & batch_inp);
    int decode(const llama_batch & batch_inp);
@@ -213,7 +216,8 @@ private:
            const llama_ubatch & ubatch,
            const llama_memory_context_i * mctx,
            llm_graph_type gtype,
            bool update_mtp_kv) const;
            bool update_mtp_kv,
            bool use_mtp_head) const;

    llm_graph_cb graph_get_cb(ggml_backend_sched * sched_override = nullptr) const;
@@ -29,6 +29,7 @@ enum llm_graph_type {
    LLM_GRAPH_TYPE_DEFAULT,
    LLM_GRAPH_TYPE_ENCODER,
    LLM_GRAPH_TYPE_DECODER,
    LLM_GRAPH_TYPE_DRAFT,
};

enum llm_ffn_op_type {
@@ -94,6 +95,20 @@ public:

using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;

class llm_graph_input_mtp_states : public llm_graph_input_i {
public:
    llm_graph_input_mtp_states() = default;
    virtual ~llm_graph_input_mtp_states() = default;

    void set_input(const llama_ubatch * /*ubatch*/) override {}

    bool can_reuse(const llm_graph_params & /*params*/) override {
        return true;
    }

    ggml_tensor * states = nullptr;
};

class llm_graph_input_embd : public llm_graph_input_i {
public:
    llm_graph_input_embd() = default;
@@ -403,6 +418,7 @@ struct llm_graph_params {
    const llama_memory_context_i * mctx;
    const llama_cross * cross;
    bool update_mtp_kv;
    bool use_mtp_head;

    uint32_t n_outputs;
@@ -451,6 +467,8 @@ struct llm_graph_params {
            cvec == other.cvec &&
            loras == other.loras &&
            cross == other.cross &&
            update_mtp_kv == other.update_mtp_kv &&
            use_mtp_head == other.use_mtp_head &&
            n_outputs == other.n_outputs;
    }
};
@@ -13787,168 +13787,204 @@ struct llm_build_glm4 : public llm_graph_context {
};

struct llm_build_glm4_moe : public llm_graph_context {
    llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params, bool build_mtp_path)
        : llm_graph_context(params) {
    llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
        const int64_t n_embd_head = hparams.n_embd_head_v;

        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

        ggml_tensor * cur;
        ggml_tensor * inpL;

        inpL = build_inp_embd(model.tok_embd);

        // inp_pos - contains the positions
        ggml_tensor * inp_pos = build_inp_pos();

        auto * inp_attn = build_attn_inp_kv_unified();

        ggml_tensor * inp_out_ids = build_inp_out_ids();

        // Only process up to last layer (skip final NextN layer)
        // Final layer tensors are loaded but not processed in forward pass
        const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
        for (int il = 0; il < n_transformer_layers; ++il) {
            ggml_tensor * inpSA = inpL;

            // Pre-attention norm
            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
            cb(cur, "attn_norm", il);

            // self-attention
            {
                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
                if (model.layers[il].bq) {
                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                }
                cb(Qcur, "Qcur", il);

                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
                if (model.layers[il].bk) {
                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                }
                cb(Kcur, "Kcur", il);

                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
                if (model.layers[il].bv) {
                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                }
                cb(Vcur, "Vcur", il);

                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

                // Apply Q/K norm if available (GLM-4.5 355B variant)
                if (model.layers[il].attn_q_norm) {
                    Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
                    cb(Qcur, "Qcur_normed", il);
                }
                if (model.layers[il].attn_k_norm) {
                    Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
                    cb(Kcur, "Kcur_normed", il);
                }

                Qcur = ggml_rope_ext(
                        ctx0, Qcur, inp_pos, nullptr,
                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                        ext_factor, attn_factor, beta_fast, beta_slow
                );

                Kcur = ggml_rope_ext(
                        ctx0, Kcur, inp_pos, nullptr,
                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                        ext_factor, attn_factor, beta_fast, beta_slow
                );

                cb(Qcur, "Qcur", il);
                cb(Kcur, "Kcur", il);
                cb(Vcur, "Vcur", il);

                cur = build_attn(inp_attn,
                        model.layers[il].wo, NULL,
                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

            if (il == n_transformer_layers - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
            cb(ffn_inp, "ffn_inp", il);

            // Post-attention norm
            cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
            cb(cur, "post_attn_norm", il);

            // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense)
            if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
                // Dense FFN layer
                cur = build_ffn(cur,
                        model.layers[il].ffn_up, NULL, NULL,
                        model.layers[il].ffn_gate, NULL, NULL,
                        model.layers[il].ffn_down, NULL, NULL,
                        NULL,
                        LLM_FFN_SILU, LLM_FFN_PAR, il);
                cb(cur, "ffn_out", il);
            } else {
                // Process routed experts using existing MoE infrastructure
                ggml_tensor * routed_out = build_moe_ffn(cur,
                        model.layers[il].ffn_gate_inp,
                        model.layers[il].ffn_up_exps,
                        model.layers[il].ffn_gate_exps,
                        model.layers[il].ffn_down_exps,
                        model.layers[il].ffn_exp_probs_b,
                        n_expert, n_expert_used,
                        LLM_FFN_SILU, hparams.expert_weights_norm,
                        true, hparams.expert_weights_scale,
                        (llama_expert_gating_func_type) hparams.expert_gating_func,
                        il);
                cb(routed_out, "ffn_moe_out", il);

                // Process shared expert on original input
                ggml_tensor * shared_out = build_ffn(cur,
                        model.layers[il].ffn_up_shexp, NULL, NULL,
                        model.layers[il].ffn_gate_shexp, NULL, NULL,
                        model.layers[il].ffn_down_shexp, NULL, NULL,
                        NULL,
                        LLM_FFN_SILU, LLM_FFN_PAR, il);
                cb(shared_out, "ffn_shexp_out", il);

                // Final output: routed_output + shared_output
                cur = ggml_add(ctx0, routed_out, shared_out);
                cb(cur, "ffn_out", il);
            }

            cur = ggml_add(ctx0, cur, ffn_inp);

            cur = build_cvec(cur, il);
            cb(cur, "l_out", il);

            // input for next layer
            inpL = cur;
        LLAMA_LOG_WARN(
            "[DEBUG-GRAPH-STATE] Building graph. MTP Head=%s, MTP KV Update=%s, n_tokens=%d\n",
            params.use_mtp_head ? "true" : "false",
            params.update_mtp_kv ? "true" : "false",
            n_tokens
        );
        // for (int i = 0; i < n_tokens; ++i) {
        // LLAMA_LOG_WARN(" - ubatch token[%d]: ID=%d, Pos=%d\n", i, ubatch.token[i], ubatch.pos[i]);
        // }
        if (n_tokens > 0) {
            LLAMA_LOG_WARN(
                " - ubatch tokens: [ID=%d, Pos=%d] ... [ID=%d, Pos=%d]\n",
                ubatch.token[0], ubatch.pos[0],
                ubatch.token[n_tokens-1], ubatch.pos[n_tokens-1]
            );
        }

        cur = inpL;
        cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
        if (params.use_mtp_head) {
            ggml_tensor* hidden_states_from_main_model;

            // cb(cur, "result_norm", -1);
            res->t_embd = cur;
            if (params.update_mtp_kv) {
                hidden_states_from_main_model = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
                ggml_set_name(hidden_states_from_main_model, "result_embd_pooled");
                ggml_set_input(hidden_states_from_main_model);

                if (build_mtp_path) {
                    const int il_mtp = hparams.n_layer - 1;
                    const auto & mtp_layer = model.layers[il_mtp];

                    ggml_tensor * mtp_logits = build_mtp_tail(mtp_layer, cur, n_embd_head);
                    res->t_logits = mtp_logits;
                    auto inp_mtp = std::make_unique<llm_graph_input_mtp_states>();
                    inp_mtp->states = hidden_states_from_main_model;
                    res->add_input(std::move(inp_mtp));
                } else {
                    // lm_head
                    cur = build_lora_mm(model.output, cur);
                    res->t_logits = cur;
                    hidden_states_from_main_model = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, hparams.n_embd);
                    ggml_set_name(hidden_states_from_main_model, "result_embd_pooled");
                    ggml_set_input(hidden_states_from_main_model);

                    auto inp_mtp = std::make_unique<llm_graph_input_mtp_states>();
                    inp_mtp->states = hidden_states_from_main_model;
                    res->add_input(std::move(inp_mtp));
                }
                res->t_embd = hidden_states_from_main_model;

                const int il_mtp = hparams.n_layer - 1;
                const auto & mtp_layer = model.layers[il_mtp];
                res->t_logits = build_mtp_tail(mtp_layer, hidden_states_from_main_model, n_embd_head);

        } else {
            ggml_tensor * inpL = build_inp_embd(model.tok_embd);
            // inp_pos - contains the positions
            ggml_tensor * inp_pos = build_inp_pos();
            auto * inp_attn = build_attn_inp_kv_unified();
            ggml_tensor * inp_out_ids = build_inp_out_ids();

            // Only process up to last layer (skip final NextN layer)
            // Final layer tensors are loaded but not processed in forward pass
            const int n_transformer_layers = n_layer - hparams.nextn_predict_layers;
            for (int il = 0; il < n_transformer_layers; ++il) {
                // if (params.use_mtp_head) {
                // LLAMA_LOG_ERROR("[DEBUG-KV-ERROR] MTP path is running the main layer %d!\n", il);
                // } else {
                // LLAMA_LOG_WARN("[DEBUG-KV] Main Head Path: Accessing layer %d\n", il);
                // }
                ggml_tensor * inpSA = inpL;

                // Pre-attention norm
                cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
                cb(cur, "attn_norm", il);

                // self-attention
                {
                    ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
                    if (model.layers[il].bq) {
                        Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
                    }
                    cb(Qcur, "Qcur", il);

                    ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
                    if (model.layers[il].bk) {
                        Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
                    }
                    cb(Kcur, "Kcur", il);

                    ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
                    if (model.layers[il].bv) {
                        Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
                    }
                    cb(Vcur, "Vcur", il);

                    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
                    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

                    // Apply Q/K norm if available (GLM-4.5 355B variant)
                    if (model.layers[il].attn_q_norm) {
                        Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
                        cb(Qcur, "Qcur_normed", il);
                    }
                    if (model.layers[il].attn_k_norm) {
                        Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
                        cb(Kcur, "Kcur_normed", il);
                    }

                    Qcur = ggml_rope_ext(
                            ctx0, Qcur, inp_pos, nullptr,
                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                            ext_factor, attn_factor, beta_fast, beta_slow
                    );

                    Kcur = ggml_rope_ext(
                            ctx0, Kcur, inp_pos, nullptr,
                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                            ext_factor, attn_factor, beta_fast, beta_slow
                    );

                    cb(Qcur, "Qcur", il);
                    cb(Kcur, "Kcur", il);
                    cb(Vcur, "Vcur", il);

                    cur = build_attn(inp_attn,
                            model.layers[il].wo, NULL,
                            Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
                }

                if (il == n_transformer_layers - 1 && inp_out_ids) {
                    cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                    inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
                }

                ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
                cb(ffn_inp, "ffn_inp", il);

                // Post-attention norm
                cur = build_norm(ffn_inp, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il);
                cb(cur, "post_attn_norm", il);

                // Check if this is a dense layer (n_layer_dense_lead=1, so layer 0 is dense)
                if (static_cast<uint32_t>(il) < hparams.n_layer_dense_lead) {
                    // Dense FFN layer
                    cur = build_ffn(cur,
                            model.layers[il].ffn_up, NULL, NULL,
                            model.layers[il].ffn_gate, NULL, NULL,
                            model.layers[il].ffn_down, NULL, NULL,
                            NULL,
                            LLM_FFN_SILU, LLM_FFN_PAR, il);
                    cb(cur, "ffn_out", il);
                } else {
                    // Process routed experts using existing MoE infrastructure
                    ggml_tensor * routed_out = build_moe_ffn(cur,
                            model.layers[il].ffn_gate_inp,
                            model.layers[il].ffn_up_exps,
                            model.layers[il].ffn_gate_exps,
                            model.layers[il].ffn_down_exps,
                            model.layers[il].ffn_exp_probs_b,
                            n_expert, n_expert_used,
                            LLM_FFN_SILU, hparams.expert_weights_norm,
                            true, hparams.expert_weights_scale,
                            (llama_expert_gating_func_type) hparams.expert_gating_func,
                            il);
                    cb(routed_out, "ffn_moe_out", il);

                    // Process shared expert on original input
                    ggml_tensor * shared_out = build_ffn(cur,
                            model.layers[il].ffn_up_shexp, NULL, NULL,
                            model.layers[il].ffn_gate_shexp, NULL, NULL,
                            model.layers[il].ffn_down_shexp, NULL, NULL,
                            NULL,
                            LLM_FFN_SILU, LLM_FFN_PAR, il);
                    cb(shared_out, "ffn_shexp_out", il);

                    // Final output: routed_output + shared_output
                    cur = ggml_add(ctx0, routed_out, shared_out);
                    cb(cur, "ffn_out", il);
                }

                cur = ggml_add(ctx0, cur, ffn_inp);

                cur = build_cvec(cur, il);
                cb(cur, "l_out", il);

                // input for next layer
                inpL = cur;
            }

            cur = inpL;
            cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);

            // cb(cur, "result_norm", -1);
            res->t_embd = cur;

            // Use the main model head
            res->t_logits = build_lora_mm(model.output, cur);
        }

        ggml_build_forward_expand(gf, res->t_logits);

    }

private:
@@ -13956,6 +13992,10 @@ private:
            int64_t n_embd_head
    ) {
        const int il = hparams.n_layer - 1;
        // LLAMA_LOG_WARN("[DEBUG-KV] MTP Head Path: Accessing layer %d\n", il);
        ggml_tensor * sum_node = ggml_sum(ctx0, prev_embeddings);

        ggml_set_name(sum_node, "mtp_input_sum");

        ggml_tensor * inp_pos = build_inp_pos();
        auto * inp_attn = build_attn_inp_kv_unified();
@@ -14015,7 +14055,11 @@ private:
        cb(Qcur, "Qcur", il);
        cb(Kcur, "Kcur", il);
        cb(Vcur, "Vcur", il);

        // LLAMA_LOG_WARN("[DEBUG-MTP-ATTN] Inputs for build_attn in the layer %d:\n", il);
        // LLAMA_LOG_WARN(" - Qcur shape: [%d, %d, %d]\n", Qcur->ne[0], Qcur->ne[1], Qcur->ne[2]);
        // LLAMA_LOG_WARN(" - Kcur shape: [%d, %d, %d]\n", Kcur->ne[0], Kcur->ne[1], Kcur->ne[2]);
        // LLAMA_LOG_WARN(" - Vcur shape: [%d, %d, %d]\n", Vcur->ne[0], Vcur->ne[1], Vcur->ne[2]);

        cur = build_attn(inp_attn,
                mtp_layer.wo, NULL,
                Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
@@ -18511,7 +18555,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            } break;
        case LLM_ARCH_GLM4_MOE:
            {
                llm = std::make_unique<llm_build_glm4_moe>(*this, params, build_mtp);
                llm = std::make_unique<llm_build_glm4_moe>(*this, params);
            } break;
        case LLM_ARCH_BITNET:
            {
@@ -3387,6 +3387,15 @@ struct server_context {
                    slot.n_prompt_tokens_processed += n_pos;
                }

                const size_t n_to_log = slot.mtp_kv_update_batch.size();
                if (n_to_log > 0) {
                    SLT_INF(slot,
                        "DEBUG-KV-REQ Cache Warm-up: Requesting KV update for %zu tokens. Positions: %d ... %d\n",
                        n_to_log,
                        slot.mtp_kv_update_batch.front().n_past,
                        slot.mtp_kv_update_batch.back().n_past
                    );
                }
                // add prompt tokens for processing in the current batch
                while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
                    // get next token to process
@@ -3517,12 +3526,12 @@ struct server_context {
                continue; // continue loop of n_batch
            }

            for (auto & slot : slots) {
                // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
                if (slot.has_mtp) {
                    mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens);
                }
            }
            // for (auto & slot : slots) {
            // // This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
            // if (slot.has_mtp) {
            // mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch, i, n_tokens);
            // }
            // }

            // move the head of the batch forward with the number of tokens we just processed
            i_next = i + n_tokens;