kludge-y kv cache management of mtp layer

Aaron Lee 2025-08-19 01:50:34 -04:00
parent 382135aa36
commit d72f9d5691
4 changed files with 32 additions and 5 deletions

View File

@@ -7,6 +7,7 @@
#include "llama-mmap.h"
#include "llama-model.h"
#include "llama-graph.h"
#include "llama-kv-cache-unified.h"
#include <cinttypes>
#include <cstring>
@@ -3000,7 +3001,20 @@ void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
    const auto * model = llama_get_model(ctx);
    auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());

    llama_memory_context_ptr mctx = ctx->mtp_memory_batch(batch_inp);
    std::unique_ptr<llama_memory_context_i> mctx = ctx->mtp_memory_batch(batch_inp);

    std::vector<uint32_t> idxs;
    idxs.push_back(n_past);

    llama_kv_cache_unified::slot_info sinfo = {
        /*.s0   =*/ 0,
        /*.s1   =*/ 0,
        /*.strm =*/ { 0 },
        /*.idxs =*/ { idxs },
    };

    llama_kv_cache_unified::slot_info_vec_t sinfos;
    sinfos.push_back(sinfo);

    static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_sinfos(sinfos);

    const auto& ubatch_mtp = mctx->get_ubatch();
    //llama_ubatch ubatch_mtp;
@@ -3012,9 +3026,10 @@ void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
    auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);

    if (mctx && !mctx->apply()) {
        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
    }
    //if (mctx && !mctx->set_n_kv()) {
    //    LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
    //}
    static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_n_kv();

    auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);
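The hunks above amount to: hand-build a slot_info that points at a single cache cell (position n_past, stream 0), force it into the unified KV-cache context with the new set_sinfos(), apply the context, then call set_n_kv() so the MTP graph is built against an up-to-date n_kv. Below is a minimal, self-contained sketch of the data being forced in; slot_info_sketch and its field comments are stand-ins inferred from the names used above, not the real declarations in llama-kv-cache-unified.h.

// Stand-in for the slot_info fields used above; the real type lives in
// llama-kv-cache-unified.h and carries more state than this sketch.
#include <cstdint>
#include <cstdio>
#include <vector>

struct slot_info_sketch {
    uint32_t s0;                              // first stream covered (assumed from the name)
    uint32_t s1;                              // last stream covered (assumed from the name)
    std::vector<uint32_t> strm;               // stream id per sequence
    std::vector<std::vector<uint32_t>> idxs;  // per-stream list of KV cells to write
};

int main() {
    const uint32_t n_past = 42; // hypothetical: the next unwritten position in the cache

    // Mirror of the hunk above: describe a single-token write into cell n_past
    // of stream 0 instead of letting the cache search for a free slot.
    std::vector<uint32_t> idxs = { n_past };

    slot_info_sketch sinfo = {
        /*.s0   =*/ 0,
        /*.s1   =*/ 0,
        /*.strm =*/ { 0 },
        /*.idxs =*/ { idxs },
    };

    std::vector<slot_info_sketch> sinfos = { sinfo };

    printf("forcing a 1-token KV write into cell %u of stream %u\n",
           (unsigned) sinfos[0].idxs[0][0], (unsigned) sinfos[0].strm[0]);
    return 0;
}

If this reading is right, the "kludge" is that the MTP path bypasses the cache's own slot search; the trailing set_n_kv() then re-derives n_kv from the cache, presumably so the graph sees the window that the forced write just extended.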

View File

@@ -2322,6 +2322,11 @@ bool llama_kv_cache_unified_context::apply() {
    return true;
}

void llama_kv_cache_unified_context::set_n_kv() {
    n_kv = kv->get_n_kv();
}

llama_memory_status llama_kv_cache_unified_context::get_status() const {
    return status;
}
@@ -2384,6 +2389,10 @@ void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
    kv->set_input_pos_bucket(dst, ubatch);
}

void llama_kv_cache_unified_context::set_sinfos(llama_kv_cache_unified::slot_info_vec_t new_sinfos) {
    sinfos = new_sinfos;
}

uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
    // the FA kernels require padding to avoid extra runtime boundary checks
    return cparams.flash_attn ? 256u : 32u;

View File

@@ -340,6 +340,7 @@ public:
    //
    uint32_t get_n_kv() const;
    void set_n_kv();

    // TODO: temporary
    bool get_supports_set_rows() const;
@@ -362,6 +363,8 @@ public:
    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

    void set_sinfos(slot_info_vec_t new_sinfos);

private:
    llama_memory_status status;

View File

@@ -3545,7 +3545,7 @@ struct server_context {
    llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
    slot.last_tok_idx = tok_idx;

    SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str());
    //SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str());

    slot.i_batch = -1;