kv-cache : add cache for indexer keys (temporary solution)

Stanisław Szymczyk 2026-03-13 17:02:59 +01:00
parent 723f0cef0b
commit 72b7214467
2 changed files with 96 additions and 4 deletions


@@ -51,7 +51,7 @@ llama_kv_cache::llama_kv_cache(
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
ggml_init_params params = {
/*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
/*.mem_size =*/ size_t(3u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
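
The bump from 2u to 3u reserves metadata for a third tensor kind per layer: besides K and V, a layer can now also hold an indexer-K tensor, and each kind needs one full 3D tensor plus n_stream per-stream views. A minimal sketch of the budget arithmetic (the numeric values are made up; ggml_tensor_overhead() is the real ggml call used above):

#include "ggml.h"
#include <cstdio>

int main() {
    const size_t n_stream   = 2;  // hypothetical number of KV streams
    const size_t n_layer_kv = 61; // hypothetical number of layers that keep a KV cache

    // per layer: 3 tensor kinds (K, V, indexer K),
    // each needing metadata for 1 full 3D tensor + n_stream 2D per-stream views
    const size_t mem_size = size_t(3u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead());

    printf("reserved tensor metadata: %zu bytes\n", mem_size);
    return 0;
}
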
@@ -113,6 +113,7 @@ llama_kv_cache::llama_kv_cache(
// [TAG_V_CACHE_VARIABLE]
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
const uint32_t n_embd_indexer_head = hparams.indexer_head_size;
const char * dev_name = "CPU";
@@ -134,24 +135,29 @@ llama_kv_cache::llama_kv_cache(
const bool has_k = true;
const bool has_v = !is_mla;
const bool has_ik = hparams.indexer_top_k > 0;
ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr;
ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr;
ggml_tensor * ik = has_ik ? ggml_new_tensor_3d(ctx, type_k, n_embd_indexer_head, kv_size, n_stream) : nullptr;
has_k && ggml_format_name(k, "cache_k_l%d", il);
has_v && ggml_format_name(v, "cache_v_l%d", il);
has_ik && ggml_format_name(ik, "cache_ik_l%d", il);
std::vector<ggml_tensor *> k_stream;
std::vector<ggml_tensor *> v_stream;
std::vector<ggml_tensor *> ik_stream;
for (uint32_t s = 0; s < n_stream; ++s) {
k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr);
v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr);
ik_stream.push_back(has_ik ? ggml_view_2d(ctx, ik, n_embd_indexer_head, kv_size, ik->nb[1], s*ik->nb[2]) : nullptr);
}
map_layer_ids[il] = layers.size();
layers.push_back({ il, k, v, k_stream, v_stream, });
layers.push_back({ il, k, v, ik, k_stream, v_stream, ik_stream });
}
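
For each layer, the new cache_ik_l%d tensor mirrors the K-cache layout: [indexer_head_size, kv_size, n_stream], with every stream exposed as a contiguous 2D view that starts s whole planes (s*ik->nb[2] bytes) into the buffer. A minimal sketch of that layout, with hypothetical sizes:

#include "ggml.h"

// sketch only: the real sizes come from hparams.indexer_head_size, the cache size and the stream count
static ggml_tensor * ik_stream_view_sketch(ggml_context * ctx, int64_t s) {
    const int64_t head_size = 128;  // hypothetical indexer head size
    const int64_t kv_size   = 4096; // hypothetical number of cache cells
    const int64_t n_stream  = 2;    // hypothetical stream count

    // one row of head_size values per cell, per stream
    ggml_tensor * ik = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, head_size, kv_size, n_stream);

    // stream s is a 2D slice that skips s full planes of kv_size rows
    return ggml_view_2d(ctx, ik, head_size, kv_size, ik->nb[1], s*ik->nb[2]);
}
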
if (reuse) {
@@ -202,11 +208,13 @@ llama_kv_cache::llama_kv_cache(
{
const size_t memory_size_k = size_k_bytes();
const size_t memory_size_v = size_v_bytes();
const size_t memory_size_ik = size_ik_bytes();
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB, IK (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_ik / (1024.0f * 1024.0f));
}
const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
@@ -656,6 +664,10 @@ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_co
if (layer.v_stream[ssrc]) {
ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
}
if (layer.ik_stream[ssrc]) {
ggml_backend_tensor_copy(layer.ik_stream[ssrc], layer.ik_stream[sdst]);
}
}
}
}
@@ -1072,6 +1084,26 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
}
ggml_tensor * llama_kv_cache::get_ik(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
const int32_t ikv = map_layer_ids.at(il);
auto * ik = layers[ikv].ik;
const uint64_t kv_size = get_size();
const uint64_t n_embd_indexer_head = ik->ne[0];
assert(n_embd_indexer_head == hparams.indexer_head_size);
const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
return ggml_view_4d(ctx, ik,
n_embd_indexer_head, 1, n_kv, ns,
ggml_row_size(ik->type, n_embd_indexer_head),
ggml_row_size(ik->type, n_embd_indexer_head),
ggml_row_size(ik->type, n_embd_indexer_head*kv_size),
ggml_row_size(ik->type, n_embd_indexer_head*kv_size)*sinfo.s0);
}
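
get_ik returns a [indexer_head_size, 1, n_kv, ns] view: the singleton second dimension stands in for the single indexer head, so the shape lines up with the [head_dim, n_head, n_kv, n_stream] views produced by get_k, while nb3 steps over whole kv_size planes so the ns streams of the slot are addressed correctly. Purely as an illustration of how such a view could be consumed for top-k cell selection (this graph code is not part of the commit; q_idx, the function name and the single-stream assumption are all hypothetical):

#include "ggml.h"

// illustrative sketch only: score cached indexer keys against per-token indexer
// queries and keep the indices of the top-k scoring cells (assumes one stream)
static ggml_tensor * indexer_topk_sketch(ggml_context * ctx,
        ggml_tensor * ik_cache, // [head_size, 1, n_kv, 1], as returned by get_ik
        ggml_tensor * q_idx,    // [head_size, n_tokens],   hypothetical indexer queries (F32)
        int           top_k) {
    // collapse the singleton head and stream dims -> [head_size, n_kv]
    ggml_tensor * ik2d = ggml_reshape_2d(ctx, ik_cache, ik_cache->ne[0], ik_cache->ne[2]);

    // dot product of every query with every cached key -> [n_kv, n_tokens]
    ggml_tensor * scores = ggml_mul_mat(ctx, ik2d, q_idx);

    // per-token indices of the top_k best-scoring cells
    return ggml_top_k(ctx, scores, top_k);
}
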
ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
GGML_UNUSED(sinfo);
@@ -1163,6 +1195,41 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
}
ggml_tensor * llama_kv_cache::cpy_ik(ggml_context * ctx, ggml_tensor * ik_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
GGML_UNUSED(sinfo);
const int32_t ikv = map_layer_ids.at(il);
ggml_tensor * ik = layers[ikv].ik;
const int64_t n_embd_indexer_head = ik_cur->ne[0];
const int64_t n_head = ik_cur->ne[1];
const int64_t n_tokens = ik_cur->ne[2];
const int64_t n_embd_gqa = n_embd_indexer_head*n_head;
// we can merge dims 0 and 1
// TODO: add ggml helper function for this?
GGML_ASSERT(ggml_row_size(ik_cur->type, n_embd_indexer_head) == ik_cur->nb[1]);
ik_cur = ggml_view_2d(ctx, ik_cur, n_embd_gqa, n_tokens, ik_cur->nb[2], 0);
const int64_t n_stream = ik->ne[2];
if (n_stream > 1) {
const int64_t kv_size = get_size();
assert(n_embd_gqa == ik->ne[0]);
assert(kv_size == ik->ne[1]);
// merge the buffer across all streams because the idxs are global
ik = ggml_reshape_2d(ctx, ik, n_embd_gqa, kv_size*n_stream);
}
// store the current indexer K values into the cache
return ggml_set_rows(ctx, ik, ik_cur, k_idxs);
}
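
As with cpy_k and cpy_v, the actual write is a ggml_set_rows scatter: row i of ik_cur goes to the cache row named by entry i of k_idxs, and because build_input_k_idxs produces indices into the global cache (kv_size*n_stream rows), the multi-stream tensor is first flattened to a single 2D buffer. A minimal sketch of the set_rows semantics (names and shapes are illustrative):

#include "ggml.h"

// sketch only: scatter n_tokens rows of cur into cache at the rows given by idxs,
// i.e. cache[idxs[i], :] = cur[i, :] for every token i
static ggml_tensor * scatter_rows_sketch(ggml_context * ctx,
        ggml_tensor * cache, // [row_width, n_rows]         destination buffer
        ggml_tensor * cur,   // [row_width, n_tokens], F32  rows to store
        ggml_tensor * idxs)  // [n_tokens], I64             destination row per token
{
    return ggml_set_rows(ctx, cache, cur, idxs);
}
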
ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
const uint32_t n_tokens = ubatch.n_tokens;
@@ -1537,6 +1604,16 @@ size_t llama_kv_cache::size_v_bytes() const {
return size_v_bytes;
}
size_t llama_kv_cache::size_ik_bytes() const {
size_t size_ik_bytes = 0;
for (const auto & layer : layers) {
size_ik_bytes += ggml_nbytes(layer.ik);
}
return size_ik_bytes;
}
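
The added memory is modest next to the K and V caches, since each layer stores only indexer_head_size values per cell per stream. A hypothetical worked example (all sizes made up for illustration):

#include "ggml.h"
#include <cstdio>

int main() {
    // hypothetical: indexer_head_size = 128, 4096 cells, 1 stream, F16 cache type
    const size_t bytes_per_layer = ggml_row_size(GGML_TYPE_F16, 128) * 4096 * 1;
    printf("indexer-K cache: %.2f MiB per layer\n", bytes_per_layer / (1024.0 * 1024.0));
    return 0;
}
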
ggml_tensor * llama_kv_cache::build_rope_shift(
const llama_cparams & cparams,
ggml_context * ctx,
@@ -2242,6 +2319,10 @@ ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) cons
return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
}
ggml_tensor * llama_kv_cache_context::get_ik(ggml_context * ctx, int32_t il) const {
return kv->get_ik(ctx, il, n_kv, sinfos[i_cur]);
}
ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
}
@@ -2250,6 +2331,10 @@ ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_
return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
}
ggml_tensor * llama_kv_cache_context::cpy_ik(ggml_context * ctx, ggml_tensor * ik_cur, ggml_tensor * k_idxs, int32_t il) const {
return kv->cpy_ik(ctx, ik_cur, k_idxs, il, sinfos[i_cur]);
}
ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
return kv->build_input_k_idxs(ctx, ubatch);
}
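
Taken together, the llama_kv_cache_context wrappers give the graph builder the same flow for indexer keys as for regular K/V: build the destination indices once, scatter the current indexer keys into the cache, then read back a view over the visible cells. A rough sketch of that wiring (a fragment that assumes llama.cpp's internal headers; ik_cur and the function name are hypothetical, and none of this graph code is part of the commit):

// sketch only, not actual llama.cpp graph code
static void build_indexer_kv_sketch(ggml_context * ctx,
        ggml_cgraph * gf,
        const llama_kv_cache_context * mctx,
        const llama_ubatch & ubatch,
        ggml_tensor * ik_cur, // [indexer_head_size, n_head, n_tokens], hypothetical current indexer keys
        int32_t il) {
    // destination cells for this ubatch (shared with the regular K cache)
    ggml_tensor * k_idxs = mctx->build_input_k_idxs(ctx, ubatch);

    // store the current indexer keys ...
    ggml_build_forward_expand(gf, mctx->cpy_ik(ctx, ik_cur, k_idxs, il));

    // ... and read back a view over the cells visible to this ubatch
    ggml_tensor * ik = mctx->get_ik(ctx, il);
    GGML_UNUSED(ik);
}
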


@@ -161,10 +161,12 @@ public:
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
ggml_tensor * get_ik(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
// store k_cur and v_cur in the cache based on the provided head location
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
ggml_tensor * cpy_ik(ggml_context * ctx, ggml_tensor * ik_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
//
// preparation API
@@ -210,9 +212,11 @@ private:
ggml_tensor * k;
ggml_tensor * v;
ggml_tensor * ik;
std::vector<ggml_tensor *> k_stream;
std::vector<ggml_tensor *> v_stream;
std::vector<ggml_tensor *> ik_stream;
};
bool v_trans = true; // the value tensor is transposed
@@ -256,6 +260,7 @@ private:
size_t size_k_bytes() const;
size_t size_v_bytes() const;
size_t size_ik_bytes() const;
ggml_tensor * build_rope_shift(
const llama_cparams & cparams,
@@ -331,6 +336,7 @@ public:
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
ggml_tensor * get_ik(ggml_context * ctx, int32_t il) const;
// store k_cur and v_cur in the cache based on the provided head location
// note: the heads in k_cur and v_cur should be laid out contiguously in memory
@@ -340,6 +346,7 @@ public:
// - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending if V cache is transposed
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
ggml_tensor * cpy_ik(ggml_context * ctx, ggml_tensor * ik_cur, ggml_tensor * k_idxs, int32_t il) const;
// create destination indices for each head of the current batch for where it would be written in the KV cache
// the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but