added proper KV cache management for MTP layers and slightly refactored
This commit is contained in:
parent
6e9bafc7a7
commit
6870f9790c
|
|
@ -370,25 +370,45 @@ llama_token mtp_speculative_gen_draft(
|
|||
int32_t n_past,
|
||||
int32_t last_tok_idx) {
|
||||
|
||||
const auto * model = llama_get_model(ctx);
|
||||
auto * last_embd = llama_get_embeddings_tensor(ctx);
|
||||
llama_token token_data[] = { id_last };
|
||||
llama_pos pos_data[] = { n_past };
|
||||
int32_t n_seq_id_data[] = { 1 };
|
||||
llama_seq_id seq_id_data_internal[] = { 0 };
|
||||
llama_seq_id* seq_id_data[] = {seq_id_data_internal};
|
||||
int8_t logits_data[] = { (int8_t) (smpl != nullptr) };
|
||||
|
||||
GGML_ASSERT(model != nullptr);
|
||||
GGML_ASSERT(last_embd != nullptr);
|
||||
llama_build_and_execute_mtp_graph(ctx, last_embd, id_last, n_past, last_tok_idx);
|
||||
llama_batch batch = {
|
||||
/*.n_tokens = */ 1,
|
||||
/*.token = */ token_data,
|
||||
/*.embd = */ nullptr,
|
||||
/*.pos = */ pos_data,
|
||||
/*.n_seq_id = */ n_seq_id_data,
|
||||
/*.seq_id = */ seq_id_data,
|
||||
/*.logits = */ logits_data
|
||||
};
|
||||
|
||||
common_sampler_sample(smpl, ctx, last_tok_idx, true);
|
||||
llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
|
||||
//LOG_INF("updating kv cache for n_past: %d\n", n_past);
|
||||
|
||||
const auto* cur_p = common_sampler_get_candidates(smpl);
|
||||
/*LOG_INF("cur_p->size: %d\n", cur_p->size);
|
||||
if (!smpl) {
|
||||
return -1;
|
||||
}
|
||||
else {
|
||||
common_sampler_sample(smpl, ctx, last_tok_idx, true);
|
||||
const auto* cur_p = common_sampler_get_candidates(smpl);
|
||||
|
||||
//for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) {
|
||||
// LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
// k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
|
||||
//}
|
||||
|
||||
const llama_token id = cur_p->data[0].id;
|
||||
return id;
|
||||
}
|
||||
// LOG_INF("cur_p->size: %d\n", cur_p->size);
|
||||
|
||||
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
|
||||
LOG_INF(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
|
||||
}*/
|
||||
|
||||
// add drafted token for each sequence
|
||||
const llama_token id = cur_p->data[0].id;
|
||||
|
||||
// skip accepting draft token -- since we're only drafting one token this can't affect future outputs
|
||||
// smpl will accept the token if it doesn't get rejected by main model later
|
||||
|
|
@ -398,5 +418,15 @@ llama_token mtp_speculative_gen_draft(
|
|||
//result.reserve(1);
|
||||
//result.push_back(id);
|
||||
//return result;
|
||||
return id;
|
||||
}
|
||||
|
||||
|
||||
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens) {
|
||||
mtp_kv_update_data token;
|
||||
for (int i = 0; i < tokens.size(); ++i) {
|
||||
token = tokens[i];
|
||||
mtp_speculative_gen_draft(nullptr, ctx, token.id, token.n_past, token.tok_idx);
|
||||
}
|
||||
|
||||
tokens.clear();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,12 @@ struct common_speculative_params {
|
|||
float p_min = 0.75f; // min probability required to accept a token in the draft
|
||||
};
|
||||
|
||||
struct mtp_kv_update_data {
|
||||
llama_token id;
|
||||
int32_t n_past;
|
||||
int32_t tok_idx;
|
||||
};
|
||||
|
||||
struct common_speculative * common_speculative_init(
|
||||
struct llama_context * ctx_tgt,
|
||||
struct llama_context * ctx_dft
|
||||
|
|
@ -42,3 +48,5 @@ llama_tokens common_speculative_gen_draft(
|
|||
struct common_speculative_params params,
|
||||
const llama_tokens & prompt,
|
||||
llama_token id_last);
|
||||
|
||||
void mtp_update_kv_cache(struct llama_context * ctx, std::vector<mtp_kv_update_data>& tokens);
|
||||
|
|
|
|||
|
|
@ -544,9 +544,6 @@ extern "C" {
|
|||
// Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
|
||||
LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
|
||||
|
||||
LLAMA_API ggml_cgraph * llama_build_mtp_graph(const struct llama_model * model, const struct llm_graph_params & params,
|
||||
struct ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past);
|
||||
|
||||
// Returns 0 on success
|
||||
LLAMA_API uint32_t llama_model_quantize(
|
||||
const char * fname_inp,
|
||||
|
|
@ -999,8 +996,6 @@ extern "C" {
|
|||
// otherwise: float[n_embd] (1-dimensional)
|
||||
LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
|
||||
|
||||
LLAMA_API ggml_tensor * llama_get_embeddings_tensor(struct llama_context * ctx);
|
||||
|
||||
//
|
||||
// Vocab
|
||||
//
|
||||
|
|
@ -1459,16 +1454,8 @@ extern "C" {
|
|||
ggml_opt_epoch_callback callback_train,
|
||||
ggml_opt_epoch_callback callback_eval);
|
||||
|
||||
LLAMA_API llm_graph_params llama_mtp_graph_params(struct llama_context* ctx, class llm_graph_result * res, const struct llama_ubatch& ubatch);
|
||||
|
||||
LLAMA_API ggml_status llama_graph_compute(struct llama_context * ctx, struct ggml_cgraph * gf, bool batched);
|
||||
|
||||
LLAMA_API void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
|
||||
ggml_tensor* hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
|
||||
|
||||
LLAMA_API ggml_tensor * llama_graph_result_get_logits(class llm_graph_result * res);
|
||||
|
||||
|
||||
const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
|||
|
|
@ -275,7 +275,9 @@ bool llama_batch_allocr::init(
|
|||
}
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
// TEMPORARILY DISABLING THIS SANITY CHECK
|
||||
// TODO: UNDO THIS IF IT WORKS
|
||||
/*if (!ok) {
|
||||
LLAMA_LOG_ERROR(
|
||||
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
|
||||
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
|
||||
|
|
@ -284,7 +286,7 @@ bool llama_batch_allocr::init(
|
|||
__func__, s, s, p0, s, seq_pos_min(s));
|
||||
|
||||
return false;
|
||||
}
|
||||
}*/
|
||||
}
|
||||
|
||||
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
|
||||
|
|
|
|||
|
|
@ -1448,8 +1448,9 @@ llm_graph_params llama_context::graph_params(
|
|||
}
|
||||
|
||||
llm_graph_params llama_context::mtp_graph_params(
|
||||
llm_graph_result* res,
|
||||
const llama_ubatch& ubatch) {
|
||||
llm_graph_result * res,
|
||||
const llama_ubatch& ubatch,
|
||||
const llama_memory_context_i * mctx) {
|
||||
size_t n_nodes = std::max<uint32_t>(1024u, 8u * 8u * (((model.hparams.nextn_predict_layers + 1) * model.n_tensors()) / model.hparams.n_layer));
|
||||
ggml_backend_sched_t temp_sched = create_temp_scheduler(n_nodes);
|
||||
return {
|
||||
|
|
@ -1462,7 +1463,7 @@ llm_graph_params llama_context::mtp_graph_params(
|
|||
/*.backend_cpu =*/ backend_cpu,
|
||||
/*.cvec =*/ &cvec,
|
||||
/*.loras =*/ &loras,
|
||||
/*.mctx =*/ memory->init_batch(*balloc, 1, false).get(),
|
||||
/*.mctx =*/ mctx,
|
||||
/*.cross =*/ &cross,
|
||||
/*.n_outputs =*/ 1,
|
||||
/*.cb =*/ graph_get_cb(temp_sched),
|
||||
|
|
@ -1470,6 +1471,21 @@ llm_graph_params llama_context::mtp_graph_params(
|
|||
};
|
||||
}
|
||||
|
||||
std::unique_ptr<llama_memory_context_i> llama_context::mtp_memory_batch(const llama_batch& batch_inp) {
|
||||
const auto& vocab = model.vocab;
|
||||
const auto& hparams = model.hparams;
|
||||
|
||||
const int64_t n_vocab = vocab.n_tokens();
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, false)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return memory->init_batch(*balloc, 1, false);
|
||||
}
|
||||
|
||||
ggml_status llama_context::graph_compute(
|
||||
ggml_cgraph * gf,
|
||||
bool batched) {
|
||||
|
|
@ -2481,13 +2497,6 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
|
|||
return ctx->get_embeddings_seq(seq_id);
|
||||
}
|
||||
|
||||
ggml_tensor * llama_get_embeddings_tensor(llama_context * ctx) {
|
||||
ctx->synchronize();
|
||||
|
||||
return ctx->get_embeddings_tensor();
|
||||
}
|
||||
|
||||
|
||||
// llama adapter API
|
||||
|
||||
int32_t llama_set_adapter_lora(
|
||||
|
|
@ -2985,42 +2994,43 @@ void llama_opt_epoch(
|
|||
callback_eval);
|
||||
}
|
||||
|
||||
llm_graph_params llama_mtp_graph_params(llama_context* ctx, llm_graph_result* res, const llama_ubatch& ubatch) {
|
||||
return ctx->mtp_graph_params(res, ubatch);
|
||||
}
|
||||
|
||||
|
||||
ggml_status llama_graph_compute(llama_context* ctx, ggml_cgraph* gf, bool batched) {
|
||||
return ctx->graph_compute(gf, batched);
|
||||
}
|
||||
|
||||
void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
|
||||
ggml_tensor * hidden_state_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
|
||||
const llama_batch batch_inp, llama_token last_token_id, int32_t n_past, int32_t last_tok_idx) {
|
||||
|
||||
const auto * model = llama_get_model(ctx);
|
||||
|
||||
auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
|
||||
llama_memory_context_ptr mctx = ctx->mtp_memory_batch(batch_inp);
|
||||
const auto& ubatch_mtp = mctx->get_ubatch();
|
||||
|
||||
llama_ubatch ubatch_mtp;
|
||||
ubatch_mtp.n_tokens = 1;
|
||||
ubatch_mtp.pos = &n_past;
|
||||
|
||||
auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp));
|
||||
|
||||
auto* gf = model->build_mtp_graph(*params_mtp, hidden_state_inp, last_token_id, n_past);
|
||||
//llama_ubatch ubatch_mtp;
|
||||
//ubatch_mtp.n_tokens = 1;
|
||||
//ubatch_mtp.pos = &n_past;
|
||||
|
||||
auto params_mtp = std::make_unique<llm_graph_params>(ctx->mtp_graph_params(res_mtp.get(), ubatch_mtp, mctx.get()));
|
||||
ggml_backend_sched_t sched = params_mtp->sched;
|
||||
|
||||
auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);
|
||||
|
||||
if (mctx && !mctx->apply()) {
|
||||
LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
|
||||
}
|
||||
|
||||
auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);
|
||||
|
||||
ggml_backend_sched_reset(sched); // clear the allocation of the previous graph
|
||||
ggml_backend_sched_alloc_graph(sched, gf); // explicitly allocate the new graph but do not execute it
|
||||
|
||||
ggml_tensor * mtp_token_id_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_token_id_input");
|
||||
|
||||
ggml_backend_tensor_set(mtp_token_id_input, &last_token_id, 0, sizeof(last_token_id)); // copy data to the newly allocated graph tensors
|
||||
|
||||
ggml_tensor * mtp_prev_embedding_input = ggml_get_tensor(res_mtp->get_ctx(), "mtp_prev_embedding_input");
|
||||
ggml_backend_tensor_set(mtp_prev_embedding_input, last_embd, 0, ggml_nbytes(mtp_prev_embedding_input)); // copy data to the newly allocated graph tensors
|
||||
|
||||
ggml_backend_sched_graph_compute(sched, gf); // execute the graph
|
||||
|
||||
struct ggml_tensor * logits_mtp = res_mtp->get_logits();;
|
||||
LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);
|
||||
//LLAMA_LOG_INFO("logits_mtp pointer address: %p\n", (void*)logits_mtp);
|
||||
|
||||
if (logits_mtp) {
|
||||
ctx->set_logits_ith(logits_mtp, sched, last_tok_idx);
|
||||
|
|
|
|||
|
|
@ -200,12 +200,14 @@ public:
|
|||
// reserve a graph with a dummy ubatch of the specified size
|
||||
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
|
||||
|
||||
llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch);
|
||||
llm_graph_params mtp_graph_params(llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx);
|
||||
|
||||
void set_logits_ith(struct ggml_tensor * logit_override, ggml_backend_sched_t sched_override, int32_t i);
|
||||
|
||||
ggml_backend_sched_t create_temp_scheduler(size_t n_nodes);
|
||||
|
||||
std::unique_ptr<llama_memory_context_i> mtp_memory_batch(const llama_batch& batch_inp);
|
||||
|
||||
private:
|
||||
llm_graph_params graph_params(
|
||||
llm_graph_result * res,
|
||||
|
|
|
|||
|
|
@ -1911,7 +1911,3 @@ int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buck
|
|||
|
||||
return relative_bucket;
|
||||
}
|
||||
|
||||
ggml_tensor * llama_graph_result_get_logits(llm_graph_result * res) {
|
||||
return res->get_logits();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
|
|||
}
|
||||
if (model.arch == LLM_ARCH_GLM4_MOE) {
|
||||
// GLM-4.5: Only process up to last layer, skip final NextN layer
|
||||
n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
|
||||
n_layer_cache = hparams.n_layer;// - hparams.nextn_predict_layers;
|
||||
}
|
||||
|
||||
// create a context for each buffer type
|
||||
|
|
|
|||
|
|
@ -13948,7 +13948,7 @@ struct llm_build_glm4_moe : public llm_graph_context {
|
|||
struct llm_build_glm4_moe_mtp : public llm_graph_context {
|
||||
llm_build_glm4_moe_mtp(const llama_model & model, const llm_graph_params & params,
|
||||
// For v0, let's rebuild the computational graph for every step + this mimics the vLLM impl parameterization
|
||||
ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past
|
||||
llama_token last_token_id, int n_past
|
||||
) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
|
@ -13961,7 +13961,8 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
|
|||
// ggml_set_i32(inp_pos, n_past);
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr;
|
||||
//llm_graph_input_attn_no_cache * inp_attn = build_attn_inp_no_cache();//nullptr;
|
||||
auto * inp_attn = build_attn_inp_kv_unified();
|
||||
|
||||
ggml_tensor * cur;
|
||||
|
||||
|
|
@ -13982,9 +13983,9 @@ struct llm_build_glm4_moe_mtp : public llm_graph_context {
|
|||
ggml_tensor * token_emb = ggml_get_rows(ctx0, mtp_layer.nextn.embed_tokens, inp_token_id);
|
||||
ggml_tensor * token_emb_norm = build_norm(token_emb, mtp_layer.nextn.enorm, NULL, LLM_NORM_RMS, il);
|
||||
|
||||
ggml_tensor * prev_embedding_leaf = ggml_dup_tensor(ctx0, hidden_state_inp);
|
||||
ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_leaf");
|
||||
ggml_cpy(ctx0, hidden_state_inp, prev_embedding_leaf);
|
||||
ggml_tensor* prev_embedding_leaf = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, model.hparams.n_embd);
|
||||
ggml_set_name(prev_embedding_leaf, "mtp_prev_embedding_input");
|
||||
ggml_set_input(prev_embedding_leaf);
|
||||
|
||||
// vLLM l99 previous_hidden_states = self.hnorm(previous_hidden_states)
|
||||
ggml_tensor * hidden_state_norm = build_norm(prev_embedding_leaf, mtp_layer.nextn.hnorm, NULL, LLM_NORM_RMS, il);
|
||||
|
|
@ -18693,13 +18694,13 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
}
|
||||
|
||||
ggml_cgraph * llama_model::build_mtp_graph(const llm_graph_params& params,
|
||||
ggml_tensor* hidden_state_inp, llama_token last_token_id, int n_past) const {
|
||||
llama_token last_token_id, int n_past) const {
|
||||
std::unique_ptr<llm_graph_context> llm;
|
||||
|
||||
switch (arch) {
|
||||
case LLM_ARCH_GLM4_MOE:
|
||||
{
|
||||
llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params, hidden_state_inp, last_token_id, n_past);
|
||||
llm = std::make_unique<llm_build_glm4_moe_mtp>(*this, params, last_token_id, n_past);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
|
|
@ -19024,10 +19025,3 @@ const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_te
|
|||
return model->tensors_by_name;
|
||||
}
|
||||
|
||||
ggml_cgraph * llama_build_mtp_graph(const llama_model * model, const llm_graph_params & params,
|
||||
ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) {
|
||||
|
||||
return model->build_mtp_graph(params, hidden_state_inp, last_token_id, n_past);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -476,7 +476,7 @@ struct llama_model {
|
|||
// TODO: move this to new llm_arch_model_i interface
|
||||
ggml_cgraph * build_graph(const llm_graph_params & params) const;
|
||||
ggml_cgraph * build_mtp_graph(const llm_graph_params & params,
|
||||
ggml_tensor * hidden_state_inp, llama_token last_token_id, int n_past) const;
|
||||
llama_token last_token_id, int n_past) const;
|
||||
|
||||
private:
|
||||
struct impl;
|
||||
|
|
|
|||
|
|
@ -1278,6 +1278,7 @@ struct server_task_result_apply_lora : server_task_result {
|
|||
}
|
||||
};
|
||||
|
||||
|
||||
struct server_slot {
|
||||
int id;
|
||||
int id_task = -1;
|
||||
|
|
@ -1295,8 +1296,9 @@ struct server_slot {
|
|||
|
||||
common_speculative * spec = nullptr;
|
||||
bool has_mtp = false;
|
||||
std::vector<mtp_kv_update_data> mtp_kv_update_batch;
|
||||
int32_t last_tok_idx = -1;
|
||||
|
||||
|
||||
std::vector<common_adapter_lora_info> lora;
|
||||
|
||||
// the index relative to completion multi-task request
|
||||
|
|
@ -1393,7 +1395,7 @@ struct server_slot {
|
|||
}
|
||||
|
||||
bool need_embd() const {
|
||||
return server_task_type_need_embd(task_type);
|
||||
return server_task_type_need_embd(task_type) || has_mtp;
|
||||
}
|
||||
|
||||
bool need_logits() const {
|
||||
|
|
@ -1569,6 +1571,7 @@ struct server_slot {
|
|||
}
|
||||
};
|
||||
|
||||
|
||||
struct server_metrics {
|
||||
int64_t t_start = 0;
|
||||
|
||||
|
|
@ -1994,7 +1997,7 @@ struct server_context {
|
|||
SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
vocab = llama_model_get_vocab(model);
|
||||
|
||||
n_ctx = llama_n_ctx(ctx);
|
||||
|
|
@ -2124,18 +2127,21 @@ struct server_context {
|
|||
common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// if model has MTP and no draft model is specified...
|
||||
else if (llama_model_n_nextn_layer(model) > 0) {
|
||||
SRV_INF("model has nextn layers = %d\n", llama_model_n_nextn_layer(model));
|
||||
slot.has_mtp = true;
|
||||
|
||||
|
||||
// assume one speculative token (true of all well-known MTP models so far)
|
||||
slot.batch_spec = llama_batch_init(2, 0, 1);
|
||||
SLT_DBG(slot, "batch_spec contains %d tokens\n", slot.batch_spec.n_tokens);
|
||||
|
||||
params_base.speculative.n_min = 0;
|
||||
params_base.speculative.n_max = 1;
|
||||
|
||||
SRV_INF("%s\n", "MTP needs embeddings on decode, enabling");
|
||||
llama_set_embeddings(ctx, true);
|
||||
}
|
||||
|
||||
SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
|
||||
|
|
@ -3383,7 +3389,11 @@ struct server_context {
|
|||
// embedding requires all tokens in the batch to be output
|
||||
const bool need_embd = server_task_type_need_embd(slot.task_type);
|
||||
|
||||
if (slot.has_mtp) {
|
||||
slot.mtp_kv_update_batch.push_back({ cur_tok, slot.n_past, batch.n_tokens });
|
||||
}
|
||||
common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
|
||||
|
||||
slot.cache_tokens.push_back(cur_tok);
|
||||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
|
|
@ -3533,6 +3543,11 @@ struct server_context {
|
|||
|
||||
const int tok_idx = slot.i_batch - i;
|
||||
|
||||
// This should only trigger on a non-empty update batch once, after prompt processing but not during token generation
|
||||
if (slot.has_mtp) {
|
||||
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
|
||||
}
|
||||
|
||||
llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
|
||||
slot.last_tok_idx = tok_idx;
|
||||
|
||||
|
|
@ -3571,8 +3586,6 @@ struct server_context {
|
|||
}
|
||||
}
|
||||
|
||||
SRV_DBG("starting speculative decoding: %d\n", 1);
|
||||
|
||||
// do speculative decoding
|
||||
for (auto & slot : slots) {
|
||||
if (!slot.is_processing() || !slot.can_speculate()) {
|
||||
|
|
@ -3631,13 +3644,9 @@ struct server_context {
|
|||
//draft.reserve(1);
|
||||
//draft.push_back(draft_id);
|
||||
|
||||
for (const auto& str : draft) {
|
||||
SLT_DBG(slot, "%s\n", str);
|
||||
}
|
||||
|
||||
// ignore small drafts
|
||||
if (slot.params.speculative.n_min > (int) draft.size()) {
|
||||
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
|
||||
if (slot.params.speculative.n_min > (int)draft.size()) {
|
||||
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int)draft.size(), slot.params.speculative.n_min);
|
||||
|
||||
continue;
|
||||
}
|
||||
|
|
@ -3661,8 +3670,12 @@ struct server_context {
|
|||
// the accepted tokens from the speculation
|
||||
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
|
||||
|
||||
// if slot has mtp
|
||||
// call
|
||||
if (slot.has_mtp) {
|
||||
for (int32_t i = 0; i < ids.size(); ++i) {
|
||||
slot.mtp_kv_update_batch.push_back({ ids[i], slot.n_past + 1 + i, i });
|
||||
}
|
||||
mtp_update_kv_cache(ctx, slot.mtp_kv_update_batch);
|
||||
}
|
||||
|
||||
slot.n_past += ids.size();
|
||||
slot.n_decoded += ids.size();
|
||||
|
|
|
|||
Loading…
Reference in New Issue