kludge-y kv cache management of mtp layer
parent 382135aa36
commit d72f9d5691
@@ -7,6 +7,7 @@
#include "llama-mmap.h"
#include "llama-model.h"
#include "llama-graph.h"
#include "llama-kv-cache-unified.h"

#include <cinttypes>
#include <cstring>
@@ -3000,7 +3001,20 @@ void llama_build_and_execute_mtp_graph(struct llama_context * ctx,
    const auto * model = llama_get_model(ctx);

    auto res_mtp = std::make_unique<llm_graph_result>(ctx->graph_max_nodes());
    llama_memory_context_ptr mctx = ctx->mtp_memory_batch(batch_inp);
    std::unique_ptr<llama_memory_context_i> mctx = ctx->mtp_memory_batch(batch_inp);

    std::vector<uint32_t> idxs;
    idxs.push_back(n_past);
    llama_kv_cache_unified::slot_info sinfo = {
        /*.s0 =*/ 0,
        /*.s1 =*/ 0,
        /*.strm =*/ { 0 },
        /*.idxs =*/ { idxs },
    };
    llama_kv_cache_unified::slot_info_vec_t sinfos;
    sinfos.push_back(sinfo);

    static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_sinfos(sinfos);
    const auto& ubatch_mtp = mctx->get_ubatch();

    //llama_ubatch ubatch_mtp;
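The hunk above is the core of the kludge: instead of letting the unified cache's normal slot search (find_slot) pick a cell for the MTP token, it hand-builds a slot_info that pins the write to the single cell at index n_past. A minimal sketch of that pattern as a standalone helper (the helper name is mine, the field annotations are my reading of the unified-cache code rather than anything stated in the commit, and a single-stream cache is assumed):

    #include "llama-kv-cache-unified.h"

    #include <cstdint>
    #include <vector>

    // Hypothetical helper: build a slot_info_vec_t that forces the unified KV
    // cache to place exactly one token into the cell at `cell_idx`.
    static llama_kv_cache_unified::slot_info_vec_t make_single_cell_sinfos(uint32_t cell_idx) {
        std::vector<uint32_t> idxs = { cell_idx };

        llama_kv_cache_unified::slot_info sinfo = {
            /*.s0   =*/ 0,        // first stream touched by this slot
            /*.s1   =*/ 0,        // last stream touched by this slot
            /*.strm =*/ { 0 },    // stream id for the (single) sequence
            /*.idxs =*/ { idxs }, // one list of cell indices per stream
        };

        return { sinfo };
    }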
@@ -3012,9 +3026,10 @@ void llama_build_and_execute_mtp_graph(struct llama_context * ctx,

    auto * last_embd = ctx->get_embeddings_ith(last_tok_idx);

    if (mctx && !mctx->apply()) {
        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
    }
    //if (mctx && !mctx->set_n_kv()) {
    //    LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
    //}
    static_cast<llama_kv_cache_unified_context*>(mctx.get())->set_n_kv();

    auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);
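Taken together with the previous hunk, the ordering is what matters: the pinned sinfos go in before apply() touches the cache, and set_n_kv() runs after apply() but before the MTP graph is built, presumably so the graph's KV views are sized to include the freshly written cell. A condensed restatement of the calls above (same identifiers as the diff, error handling elided, not new API):

    auto * kvctx = static_cast<llama_kv_cache_unified_context *>(mctx.get());

    kvctx->set_sinfos(sinfos);  // 1. pin the target cell for the MTP token
    mctx->apply();              // 2. commit the pinned slot to the cache
    kvctx->set_n_kv();          // 3. refresh the cached cell count used for graph building
    auto * gf = model->build_mtp_graph(*params_mtp, last_token_id, n_past);  // 4. build the MTP graph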
@@ -2322,6 +2322,11 @@ bool llama_kv_cache_unified_context::apply() {
    return true;
}

void llama_kv_cache_unified_context::set_n_kv() {
    n_kv = kv->get_n_kv();
}


llama_memory_status llama_kv_cache_unified_context::get_status() const {
    return status;
}
@@ -2384,6 +2389,10 @@ void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, con
    kv->set_input_pos_bucket(dst, ubatch);
}

void llama_kv_cache_unified_context::set_sinfos(llama_kv_cache_unified::slot_info_vec_t new_sinfos) {
    sinfos = new_sinfos;
}

uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
    // the FA kernels require padding to avoid extra runtime boundary checks
    return cparams.flash_attn ? 256u : 32u;
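Both new methods are plain setters on llama_kv_cache_unified_context rather than additions to the generic llama_memory_context_i interface, which is why the MTP path above has to downcast the memory-context pointer. A slightly more defensive variant of that downcast, as a hypothetical alternative (assumes RTTI is enabled in the build):

    auto * kvctx = dynamic_cast<llama_kv_cache_unified_context *>(mctx.get());
    if (kvctx == nullptr) {
        // the MTP path only makes sense with the unified KV cache
        LLAMA_LOG_ERROR("%s: expected a unified KV cache context\n", __func__);
        return;
    }
    kvctx->set_sinfos(sinfos);
    // ... followed by mctx->apply() and kvctx->set_n_kv(), as in the hunks above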
@@ -340,6 +340,7 @@ public:
    //

    uint32_t get_n_kv() const;
    void set_n_kv();

    // TODO: temporary
    bool get_supports_set_rows() const;
@@ -362,6 +363,8 @@ public:
    void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;

    void set_sinfos(slot_info_vec_t new_sinfos);

private:
    llama_memory_status status;
@@ -3545,7 +3545,7 @@ struct server_context {

    llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
    slot.last_tok_idx = tok_idx;
    SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str());
    //SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str());

    slot.i_batch = -1;
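The server-side change just records where the sampled token sat in the batch; the MTP hunk further up then pulls that token's hidden state with ctx->get_embeddings_ith(last_tok_idx), presumably to condition the draft layer on it. Paraphrasing the two sides of that hand-off (no new API, only calls that appear in the diff):

    slot.last_tok_idx = tok_idx;                                    // server: remember the batch index of the sampled token
    auto * last_embd  = ctx->get_embeddings_ith(slot.last_tok_idx); // MTP path: fetch its hidden state later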