mtp-batch (wip): move mtp execution to batch format
This commit is contained in:
parent c6237c71ff · commit 1318b2de82
@@ -374,47 +374,54 @@ llama_token mtp_speculative_gen_draft(
        return -1;
    }

    llama_batch batch = llama_batch_init(1, 0, 1);
    common_batch_add(batch, id_last, n_past, {0}, true);
    llama_batch mtp_batch = llama_batch_init(1, 0, 1);
    common_batch_add(mtp_batch, id_last, n_past, {0}, true);
    mtp_batch.update_mtp_kv = true;

    llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
    llama_decode(ctx, mtp_batch);
    llama_batch_free(mtp_batch);

    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);
    const int n_vocab = llama_n_vocab(vocab);

    llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);

    cur_p->size = n_vocab;
    for (int i = 0; i < n_vocab; ++i) {
        cur_p->data[i].id = i;
        cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i];
        cur_p->data[i].logit = llama_get_logits_ith(ctx, 0)[i]; // TODO: check if position 0 is the right one
    }
    cur_p->sorted = false;

    common_sampler_apply_chain(smpl, cur_p);

    const llama_token id = cur_p->data[0].id;

    llama_batch_free(batch);

    return id;

    return cur_p->data[0].id;
}


void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens, size_t batch_start, size_t n_tokens) {
    mtp_kv_update_data token;

    if (tokens.empty()) {
        tokens.clear();
        return;
    }
    if (n_tokens < 0) {
        n_tokens = tokens.size();
    }
    const size_t n_to_process = std::min((size_t)tokens.size(), n_tokens);

    for (int i = 0; i < std::min(tokens.size(), n_tokens); ++i) {
        token = tokens[i];
        //fprintf(stderr, "updating mtp kv cache with token (%d, %d, %d)\n", token.id, token.n_past, (int) (token.tok_idx - batch_start));

        mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx - batch_start);
    LOG_DBG(
        "[MTP BATCHING] mtp_update_kv_cache call for %zu tokens.\n",
        n_to_process
    );
    llama_batch mtp_batch = llama_batch_init(n_to_process, 0, 1);

    for (size_t i = 0; i < n_to_process; ++i) {
        const mtp_kv_update_data& token_data = tokens[i];
        common_batch_add(mtp_batch, token_data.id, token_data.n_past, {0}, false);
    }

    mtp_batch.update_mtp_kv = true;

    llama_decode(ctx, mtp_batch);

    llama_batch_free(mtp_batch);
    tokens.clear();
}
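For readers skimming the interleaved old/new lines above: the draft path no longer drives the MTP graph directly through llama_build_and_execute_mtp_graph; it packs the pending work into an ordinary llama_batch, sets the new update_mtp_kv flag, and lets llama_decode run the MTP pass. The sketch below condenses that new call pattern under the assumptions of this WIP commit (the update_mtp_kv field and the mtp_kv_update_data struct are taken from the hunk above); it is a reading aid, not part of the diff.

// Sketch only: batched MTP KV-cache update as introduced in this commit (WIP API).
#include <vector>
#include "llama.h"
#include "common.h"  // common_batch_add

static void mtp_update_kv_cache_sketch(llama_context * ctx, std::vector<mtp_kv_update_data> & tokens) {
    if (tokens.empty()) {
        return;
    }

    // One llama_batch for all pending tokens, instead of one MTP graph launch per token.
    llama_batch mtp_batch = llama_batch_init((int32_t) tokens.size(), 0, 1);
    for (size_t i = 0; i < tokens.size(); ++i) {
        // logits are not needed for a pure KV-cache update, hence `false`
        common_batch_add(mtp_batch, tokens[i].id, tokens[i].n_past, { 0 }, false);
    }

    mtp_batch.update_mtp_kv = true; // route this batch through the MTP branch of llama_decode()
    llama_decode(ctx, mtp_batch);

    llama_batch_free(mtp_batch);
    tokens.clear();
}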
@@ -230,6 +230,7 @@ extern "C" {
        int32_t      * n_seq_id;
        llama_seq_id ** seq_id;
        int8_t       * logits; // TODO: rename this to "output"
        bool update_mtp_kv;
    } llama_batch;

    enum llama_model_kv_override_type {
@@ -1454,8 +1455,8 @@ extern "C" {
            ggml_opt_epoch_callback callback_train,
            ggml_opt_epoch_callback callback_eval);

    LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
        const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
    // LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
    //     const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);

#ifdef __cplusplus
}
@@ -834,13 +834,14 @@ struct llama_batch llama_batch_get_one(

struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
    llama_batch batch = {
        /*n_tokens =*/ 0,
        /*tokens   =*/ nullptr,
        /*embd     =*/ nullptr,
        /*pos      =*/ nullptr,
        /*n_seq_id =*/ nullptr,
        /*seq_id   =*/ nullptr,
        /*logits   =*/ nullptr,
        /*n_tokens =*/ 0,
        /*tokens   =*/ nullptr,
        /*embd     =*/ nullptr,
        /*pos      =*/ nullptr,
        /*n_seq_id =*/ nullptr,
        /*seq_id   =*/ nullptr,
        /*logits   =*/ nullptr,
        /*update_mtp_kv =*/ false,
    };

    if (embd) {
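The llama_batch_init change above zero-initializes the new flag, so existing callers keep the old behavior unless they opt in. A minimal check of that contract, assuming this commit's headers:

// Minimal sketch: batches default to update_mtp_kv == false; callers opt in explicitly.
#include <cassert>
#include "llama.h"

int main() {
    llama_batch batch = llama_batch_init(/*n_tokens_alloc =*/ 8, /*embd =*/ 0, /*n_seq_max =*/ 1);
    assert(!batch.update_mtp_kv);   // default: normal decode, no MTP KV update
    batch.update_mtp_kv = true;     // opt in: llama_decode() will also run the MTP pass
    llama_batch_free(batch);
    return 0;
}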
@@ -1070,6 +1070,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
    };

    int64_t n_outputs_prev = 0;
    const bool do_mtp_kv_update = batch_inp.update_mtp_kv;

    do {
        const auto & ubatch = mctx->get_ubatch();
@@ -1129,6 +1130,39 @@ int llama_context::decode(const llama_batch & batch_inp) {
        // ggml_graph_dump_dot(gf, NULL, "llama.dot");
        //}

        if (do_mtp_kv_update) {
            LLAMA_LOG_INFO(
                "[MTP BATCHING] Processing MTP KV update for a ubatch of %u tokens.\n",
                ubatch.n_tokens
            );
            auto res_mtp = std::make_unique<llm_graph_result>(graph_max_nodes());

            auto params_mtp = mtp_graph_params(res_mtp.get(), ubatch, mctx.get());
            ggml_backend_sched_t sched_mtp = params_mtp.sched;

            auto * gf_mtp = model.build_mtp_graph(params_mtp);
            if (gf_mtp) {
                ggml_backend_sched_alloc_graph(sched_mtp, gf_mtp);

                ggml_tensor* prev_embedding_tensor = res->get_embd();
                ggml_tensor* embd_input_mtp = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embeddings_batch_input");

                // ggml_backend_tensor_set(embd_input_mtp, prev_embedding_tensor->data, 0, ggml_nbytes(prev_embedding_tensor));
                ggml_backend_tensor_copy(prev_embedding_tensor, embd_input_mtp);

                ggml_backend_sched_graph_compute(sched_mtp, gf_mtp);

                if (ubatch.output[0]) {
                    struct ggml_tensor * logits_mtp = res_mtp->get_logits();
                    if (logits_mtp) {
                        float * logits_dest = logits + n_outputs_prev * n_vocab;
                        ggml_backend_tensor_get(logits_mtp, logits_dest, 0, ggml_nbytes(logits_mtp));
                    }
                }
            }
            ggml_backend_sched_free(sched_mtp);
        }

        auto * t_logits = res->get_logits();
        auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
        embd_tensor = res->get_embd();
@@ -2995,79 +3029,79 @@ void llama_opt_epoch(
        callback_eval);
}

void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
    const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
// void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
//     const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {

    const auto * model = llama_get_model(ctx);
    // const auto * model = llama_get_model(ctx);

    auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
    std::unique_ptr<llama_memory_context_i> mctx = ctx->mtp_memory_batch(batch_inp);
    // auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
    // std::unique_ptr<llama_memory_context_i> mctx = ctx->mtp_memory_batch(batch_inp);

    std::vector<uint32_t> idxs;
    idxs.push_back(n_past);
    llama_kv_cache_unified::slot_info sinfo = {
        /*.s0 =*/ 0,
        /*.s1 =*/ 0,
        /*.strm =*/ { 0 },
        /*.idxs =*/ { idxs },
    };
    llama_kv_cache_unified::slot_info_vec_t sinfos;
    sinfos.push_back(sinfo);
    // std::vector<uint32_t> idxs;
    // idxs.push_back(n_past);
    // llama_kv_cache_unified::slot_info sinfo = {
    //     /*.s0 =*/ 0,
    //     /*.s1 =*/ 0,
    //     /*.strm =*/ { 0 },
    //     /*.idxs =*/ { idxs },
    // };
    // llama_kv_cache_unified::slot_info_vec_t sinfos;
    // sinfos.push_back(sinfo);

    static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_sinfos(sinfos);
    const auto& ubatch_mtp = mctx->get_ubatch();
    // static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_sinfos(sinfos);
    // const auto& ubatch_mtp = mctx->get_ubatch();

    //llama_ubatch ubatch_mtp;
    //ubatch_mtp.n_tokens = 1;
    //ubatch_mtp.pos = &n_past;
    // //llama_ubatch ubatch_mtp;
    // //ubatch_mtp.n_tokens = 1;
    // //ubatch_mtp.pos = &n_past;

    auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get()));
    ggml_backend_sched_t sched = params_mtp->sched;
    // auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get()));
    // ggml_backend_sched_t sched = params_mtp->sched;

    auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);
    // auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);

    //if (mctx && !mctx->set_n_kv()) {
    //    LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
    //}
    static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_n_kv();
    // //if (mctx && !mctx->set_n_kv()) {
    // //    LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
    // //}
    // static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_n_kv();

    auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);
    // auto * gf = model->build_mtp_graph(*params_mtp);

    if (!gf) {
        LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__);
        if (sched) ggml_backend_sched_free(sched);
        return;
    }
    // if (!gf) {
    //     LLAMA_LOG_ERROR("%s: ERROR - The construction of the MTP graph failed (returned null).", __func__);
    //     if (sched) ggml_backend_sched_free(sched);
    //     return;
    // }

    ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
    ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it
    // ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
    // ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it

    ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");
    ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors
    // ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");
    // ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors

    ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input");
    ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors
    // ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input");
    // ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors

    ggml_backend_sched_graph_compute(sched, gf); // execute the graph
    // ggml_backend_sched_graph_compute(sched, gf); // execute the graph

    struct ggml_tensor * logits_mtp = res_mtp->get_logits();
    // struct ggml_tensor * logits_mtp = res_mtp->get_logits();

    if (logits_mtp) {
        float * logits_dest = ctx->get_logits_ith(last_tok_idx);
        ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp);
        if (backend_res) {
            // ggml_backend_tensor_get is the function for GPU->CPU copies.
            // We are copying the full logits tensor, not a single value.
            ggml_backend_tensor_get(logits_mtp,
                logits_dest,              // Pointer to our C++ variable
                0,                        // Starting offset in bytes
                ggml_nbytes(logits_mtp)); // Number of bytes to copy
        } else {
            LLAMA_LOG_ERROR("%s: ERROR - Could not obtain the backend for the logits tensor.", __func__);
        }
    } else {
        LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__);
    }
    // if (logits_mtp) {
    //     float * logits_dest = ctx->get_logits_ith(last_tok_idx);
    //     ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched, logits_mtp);
    //     if (backend_res) {
    //         // ggml_backend_tensor_get is the function for GPU->CPU copies.
    //         // We are copying the full logits tensor, not a single value.
    //         ggml_backend_tensor_get(logits_mtp,
    //             logits_dest,              // Pointer to our C++ variable
    //             0,                        // Starting offset in bytes
    //             ggml_nbytes(logits_mtp)); // Number of bytes to copy
    //     } else {
    //         LLAMA_LOG_ERROR("%s: ERROR - Could not obtain the backend for the logits tensor.", __func__);
    //     }
    // } else {
    //     LLAMA_LOG_WARN("%s: WARNING - The MTP graph did not produce a logit tensor.", __func__);
    // }

    ggml_backend_sched_free(sched);
}
// ggml_backend_sched_free(sched);
// }
@@ -1074,6 +1074,26 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
    return cur;
}


ggml_tensor * llm_graph_context::build_inp_embd_mtp(ggml_tensor * mtp_tok_embd) const {
    auto inp = std::make_unique<llm_graph_input_embd>();
    ggml_tensor * cur = nullptr;

    if (ubatch.token) {
        inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
        ggml_set_name(inp->tokens, "mtp_inp_tokens");
        ggml_set_input(inp->tokens);

        cur = ggml_get_rows(ctx0, mtp_tok_embd, inp->tokens);
    } else {
        GGML_ABORT("fatal error: MTP update expects token IDs, not embeddings");
    }

    cb(cur, "mtp_inp_embd", -1);
    res->add_input(std::move(inp));
    return cur;
}

ggml_tensor * llm_graph_context::build_inp_pos() const {
    auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
@@ -664,6 +664,7 @@ struct llm_graph_context {
    //

    ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const;
    ggml_tensor * build_inp_embd_mtp(ggml_tensor * mtp_tok_embd) const;
    ggml_tensor * build_inp_pos() const;
    ggml_tensor * build_inp_attn_scale() const;
    ggml_tensor * build_inp_out_ids() const;
@@ -13946,54 +13946,29 @@ struct llm_build_glm4_moe : public llm_graph_context {
};

struct llm_build_glm4_moe_mtp : public llm_graph_context {
    llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params,
        // For v0, let's rebuild the computational graph for every step + this mimics the vLLM impl parameterization
        llama_token last_token_id, int n_past
    ) : llm_graph_context(params) {
    llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {

        const int64_t n_embd_head = hparams.n_embd_head_v;
        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);

        // Assuming a single MTP layer at the end
        const int il = hparams.n_layer - 1;
        const auto & mtp_layer = model.layers[il];

        // ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
        // ggml_set_i32(inp_pos, n_past);
        ggml_tensor * inp_pos = build_inp_pos();

        //llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr;
        auto * inp_attn = build_attn_inp_kv_unified();

        // get MTP embedding for last (conventionally sampled) token
        // ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
        // LLAMA_LOG_INFO("step: '%d'\n", 5641);
        // ggml_set_i32(inp_token_id, last_token_id);
        //ggml_set_no_alloc(ctx0, false);
        //LLAMA_LOG_INFO("last token id: '%d'\n", last_token_id);

        ggml_tensor * inp_token_id = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, 1);
        ggml_set_name(inp_token_id, "mtp_token_id_input");
        ggml_set_input(inp_token_id);

        //ggml_tensor * inp_token_id = ggml_new_i32(ctx0, last_token_id);
        //ggml_set_no_alloc(ctx0, true);
        ggml_tensor* prev_embeddings_batch = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, model.hparams.n_embd, n_tokens);
        ggml_set_name(prev_embeddings_batch, "mtp_prev_embeddings_batch_input");
        ggml_set_input(prev_embeddings_batch);

        ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id);
        ggml_tensor * token_emb = build_inp_embd_mtp(mtp_layer.nextn.embed_tokens);

        ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
        ggml_tensor * hidden_state_norm = build_norm(prev_embeddings_batch, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);

        ggml_tensor* prev_embedding_leaf = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, model.hparams.n_embd);
        ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_input");
        ggml_set_input(prev_embedding_leaf);
        ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0);

        // vLLM l99 previous_hidden_states = self.hnorm(previous_hidden_states)
        ggml_tensor * hidden_state_norm = build_norm(prev_embedding_leaf, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
        //token_emb_norm = ggml_cont(ctx0, token_emb_norm);
        //hidden_state_norm = ggml_cont(ctx0, hidden_state_norm);

        ggml_tensor * combined = ggml_concat(ctx0, token_emb_norm, hidden_state_norm, 0); // torch.cat

        ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined); // eh_proj
        ggml_tensor* cur = build_lora_mm(mtp_layer.nextn.eh_proj, combined);

        // now proceed through last layer (skipped in main model)
        ggml_tensor * inpSA = cur;
@@ -14090,11 +14065,11 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
            cb(cur, "ffn_out", il);
        }
        cur = ggml_add(ctx0, cur, ffn_inp);

        cur = build_norm(cur, mtp_layer.nextn.shared_head_norm, NULL, LLM_NORM_RMS, il);
        cur = build_lora_mm(mtp_layer.nextn.shared_head_head, cur);


        res->t_logits = cur;

        ggml_build_forward_expand(gf, res->t_logits);
    }
};
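As a reading aid for the graph builder above (an interpretation of this commit's nextn tensor names, not normative documentation), the MTP head roughly computes:

    h_{\text{in}} = W_{\text{eh\_proj}} \cdot \big[\, \mathrm{RMSNorm}_{\text{enorm}}(E_{\text{mtp}}[\text{token}]) \,\Vert\, \mathrm{RMSNorm}_{\text{hnorm}}(h_{\text{prev}}) \,\big]
    \text{logits} = W_{\text{shared\_head\_head}} \cdot \mathrm{RMSNorm}_{\text{shared\_head\_norm}}\big( \mathrm{Layer}_{\text{last}}(h_{\text{in}}) \big)

where h_prev is the hidden state the main model produced for the same position (fed in through the "mtp_prev_embeddings_batch_input" leaf) and ‖ is concatenation along the embedding dimension (the ggml_concat call above).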
@@ -18689,14 +18664,13 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
    return llm->res->get_gf();
}

ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params,
    llama_token last_token_id, int n_past) const {
ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params) const {
    std::unique_ptr<llm_graph_context> llm;

    switch (arch) {
        case LLM_ARCH_GLM4_MOE:
            {
                llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params, last_token_id, n_past);
                llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params);
            } break;
        default:
            GGML_ABORT("fatal error");
@@ -475,8 +475,7 @@ struct llama_model {

    // TODO: move this to new llm_arch_model_i interface
    ggml_cgraph * build_graph(const llm_graph_params & params) const;
    ggml_cgraph * build_mtp_graph(const llm_graph_params & params,
        llama_token last_token_id, int n_past) const;
    ggml_cgraph * build_mtp_graph(const llm_graph_params& params) const;

private:
    struct impl;