diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 34d514387b..a09ac6d447 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -7,6 +7,7 @@
 #include "llama-mmap.h"
 #include "llama-model.h"
 #include "llama-graph.h"
+#include "llama-kv-cache-unified.h"
 
 #include
 #include
@@ -3000,7 +3001,20 @@ void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
     const auto * model = llama_get_model(ctx);
     auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
 
-    llama_memory_context_ptr mctx = ctx->mtp_memory_batch(batch_inp);
+    std::unique_ptr<llama_memory_context_i> mctx = ctx->mtp_memory_batch(batch_inp);
+
+    std::vector<uint32_t> idxs;
+    idxs.push_back(n_past);
+    llama_kv_cache_unified::slot_info sinfo = {
+        /*.s0   =*/ 0,
+        /*.s1   =*/ 0,
+        /*.strm =*/ { 0 },
+        /*.idxs =*/ { idxs },
+    };
+    llama_kv_cache_unified::slot_info_vec_t sinfos;
+    sinfos.push_back(sinfo);
+
+    static_cast<llama_kv_cache_unified_context *>(mctx.get())->set_sinfos(sinfos);
 
     const auto& ubatch_mtp = mctx->get_ubatch();
     //llama_ubatch ubatch_mtp;
@@ -3012,9 +3026,10 @@ void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
 
     auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);
 
-    if (mctx && !mctx->apply()) {
-        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
-    }
+    //if (mctx && !mctx->set_n_kv()) {
+    //    LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
+    //}
+    static_cast<llama_kv_cache_unified_context *>(mctx.get())->set_n_kv();
 
     auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);
 
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index ed6cf969d4..53466264cd 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -2322,6 +2322,11 @@ bool llama_kv_cache_unified_context::apply() {
     return true;
 }
 
+void llama_kv_cache_unified_context::set_n_kv() {
+    n_kv = kv->get_n_kv();
+}
+
+
 llama_memory_status llama_kv_cache_unified_context::get_status() const {
     return status;
 }
@@ -2384,6 +2389,10 @@ void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, con
     kv->set_input_pos_bucket(dst, ubatch);
 }
 
+void llama_kv_cache_unified_context::set_sinfos(llama_kv_cache_unified::slot_info_vec_t new_sinfos) {
+    sinfos = new_sinfos;
+}
+
 uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
     // the FA kernels require padding to avoid extra runtime boundary checks
     return cparams.flash_attn ? 256u : 32u;
diff --git a/src/llama-kv-cache-unified.h b/src/llama-kv-cache-unified.h
index 342a675962..c02607c2d0 100644
--- a/src/llama-kv-cache-unified.h
+++ b/src/llama-kv-cache-unified.h
@@ -340,6 +340,7 @@ public:
     //
 
     uint32_t get_n_kv() const;
+    void set_n_kv();
 
     // TODO: temporary
     bool get_supports_set_rows() const;
@@ -362,6 +363,8 @@ public:
 
     void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
     void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
+    void set_sinfos(slot_info_vec_t new_sinfos);
+
 private:
     llama_memory_status status;
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index e323f7b521..1191564dd2 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -3545,7 +3545,7 @@ struct server_context {
             llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
             slot.last_tok_idx = tok_idx;
 
-            SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str());
+            //SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str());
 
             slot.i_batch = -1;
 
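
For reference, the `llama-context.cpp` hunk above skips `mctx->apply()` and instead hand-builds a single `slot_info` that pins the drafted token's KV entry to cell `n_past` on stream 0, installs it with the new `set_sinfos()`, and then only refreshes `n_kv` via `set_n_kv()`. The sketch below is not llama.cpp code; it uses a hypothetical mirror struct (`slot_info_sketch`, `make_mtp_sinfos`) purely to illustrate the shape of the data the patch constructs: one stream (`s0 == s1 == 0`) whose single index list contains only `n_past`, wrapped in a one-element `sinfos` vector to match the one-token MTP ubatch.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirror of the fields named by the designated-initializer comments in the patch
// (/*.s0 =*/ etc.); the real type is llama_kv_cache_unified::slot_info.
struct slot_info_sketch {
    int32_t s0;                                // first stream covered by the slot
    int32_t s1;                                // last stream covered by the slot
    std::vector<int32_t>               strm;   // stream ids, one per stream
    std::vector<std::vector<uint32_t>> idxs;   // KV cell indices, one list per stream
};

// Build the single-slot, single-token layout used for the MTP draft step:
// stream 0, with exactly one KV cell index equal to n_past.
static std::vector<slot_info_sketch> make_mtp_sinfos(uint32_t n_past) {
    std::vector<uint32_t> idxs = { n_past };
    slot_info_sketch sinfo = {
        /*.s0   =*/ 0,
        /*.s1   =*/ 0,
        /*.strm =*/ { 0 },
        /*.idxs =*/ { idxs },
    };
    return { sinfo };
}

int main() {
    const auto sinfos = make_mtp_sinfos(/*n_past=*/42);
    std::printf("stream %d -> KV cell %u\n", sinfos[0].strm[0], sinfos[0].idxs[0][0]);
}
```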