From 72b721446726fb029b83bb746566df187079bf60 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stanis=C5=82aw=20Szymczyk?=
Date: Fri, 13 Mar 2026 17:02:59 +0100
Subject: [PATCH] kv-cache : add cache for indexer keys (temporary solution)

---
 src/llama-kv-cache.cpp | 95 ++++++++++++++++++++++++++++++++++++++++--
 src/llama-kv-cache.h   |  7 ++++
 2 files changed, 98 insertions(+), 4 deletions(-)
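Usage note: the indexer-key cache is driven through the same index-based API as
the regular K/V tensors. Below is a minimal sketch of a graph-build call site
(ctx0, gf, mctx, ubatch, il and ik_cur follow llama.cpp naming conventions; the
exact call site is model-specific and not part of this patch):

    // ik_cur: per-token indexer keys, shape [indexer_head_size, n_head, n_tokens]
    ggml_tensor * k_idxs = mctx->build_input_k_idxs(ctx0, ubatch);

    // scatter the new indexer keys into the cache rows selected by k_idxs
    ggml_build_forward_expand(gf, mctx->cpy_ik(ctx0, ik_cur, k_idxs, il));

    // read back a view over the cached indexer keys of the current slot
    ggml_tensor * ik = mctx->get_ik(ctx0, il);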
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 82fe58fac4..bea96501f9 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -51,7 +51,7 @@ llama_kv_cache::llama_kv_cache(
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             ggml_init_params params = {
-                /*.mem_size   =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
+                /*.mem_size   =*/ size_t(3u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc   =*/ true,
             };
@@ -113,6 +113,7 @@
         // [TAG_V_CACHE_VARIABLE]
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
         const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
+        const uint32_t n_embd_indexer_head = hparams.indexer_head_size;
 
         const char * dev_name = "CPU";
 
@@ -134,24 +135,29 @@
 
         const bool has_k = true;
         const bool has_v = !is_mla;
+        const bool has_ik = hparams.indexer_top_k > 0;
 
         ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr;
         ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr;
+        ggml_tensor * ik = has_ik ? ggml_new_tensor_3d(ctx, type_k, n_embd_indexer_head, kv_size, n_stream) : nullptr;
 
         has_k && ggml_format_name(k, "cache_k_l%d", il);
         has_v && ggml_format_name(v, "cache_v_l%d", il);
+        has_ik && ggml_format_name(ik, "cache_ik_l%d", il);
 
         std::vector<ggml_tensor *> k_stream;
         std::vector<ggml_tensor *> v_stream;
+        std::vector<ggml_tensor *> ik_stream;
 
         for (uint32_t s = 0; s < n_stream; ++s) {
             k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr);
             v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr);
+            ik_stream.push_back(has_ik ? ggml_view_2d(ctx, ik, n_embd_indexer_head, kv_size, ik->nb[1], s*ik->nb[2]) : nullptr);
         }
 
         map_layer_ids[il] = layers.size();
 
-        layers.push_back({ il, k, v, k_stream, v_stream, });
+        layers.push_back({ il, k, v, ik, k_stream, v_stream, ik_stream });
     }
 
     if (reuse) {
@@ -202,11 +208,13 @@
     {
         const size_t memory_size_k = size_k_bytes();
         const size_t memory_size_v = size_v_bytes();
+        const size_t memory_size_ik = size_ik_bytes();
 
-        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
+        LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB, IK (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_k + memory_size_v + memory_size_ik) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream,
                 ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
-                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
+                ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f),
+                ggml_type_name(type_k), (float)memory_size_ik / (1024.0f * 1024.0f));
     }
 
     const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
@@ -656,6 +664,10 @@ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_co
                 if (layer.v_stream[ssrc]) {
                     ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
                 }
+
+                if (layer.ik_stream[ssrc]) {
+                    ggml_backend_tensor_copy(layer.ik_stream[ssrc], layer.ik_stream[sdst]);
+                }
             }
         }
     }
@@ -1072,6 +1084,26 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k
             ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
 }
 
+ggml_tensor * llama_kv_cache::get_ik(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
+    const int32_t ikv = map_layer_ids.at(il);
+
+    auto * ik = layers[ikv].ik;
+
+    const uint64_t kv_size = get_size();
+    const uint64_t n_embd_indexer_head = ik->ne[0];
+
+    assert(n_embd_indexer_head == hparams.indexer_head_size);
+
+    const uint32_t ns = sinfo.s1 - sinfo.s0 + 1;
+
+    return ggml_view_4d(ctx, ik,
+            n_embd_indexer_head, 1, n_kv, ns,
+            ggml_row_size(ik->type, n_embd_indexer_head),
+            ggml_row_size(ik->type, n_embd_indexer_head),
+            ggml_row_size(ik->type, n_embd_indexer_head*kv_size),
+            ggml_row_size(ik->type, n_embd_indexer_head*kv_size)*sinfo.s0);
+}
+
 ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
     GGML_UNUSED(sinfo);
 
@@ -1163,6 +1195,41 @@ ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggm
     return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
 }
 
+ggml_tensor * llama_kv_cache::cpy_ik(ggml_context * ctx, ggml_tensor * ik_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
+    GGML_UNUSED(sinfo);
+
+    const int32_t ikv = map_layer_ids.at(il);
+
+    ggml_tensor * ik = layers[ikv].ik;
+
+    const int64_t n_embd_indexer_head = ik_cur->ne[0];
+    const int64_t n_head   = ik_cur->ne[1];
+    const int64_t n_tokens = ik_cur->ne[2];
+
+    const int64_t n_embd_gqa = n_embd_indexer_head*n_head;
+
+    // we can merge dims 0 and 1
+    // TODO: add ggml helper function for this?
+    GGML_ASSERT(ggml_row_size(ik_cur->type, n_embd_indexer_head) == ik_cur->nb[1]);
+
+    ik_cur = ggml_view_2d(ctx, ik_cur, n_embd_gqa, n_tokens, ik_cur->nb[2], 0);
+
+    const int64_t n_stream = ik->ne[2];
+
+    if (n_stream > 1) {
+        const int64_t kv_size = get_size();
+
+        assert(n_embd_gqa == ik->ne[0]);
+        assert(kv_size    == ik->ne[1]);
+
+        // merge the buffer across all streams because the idxs are global
+        ik = ggml_reshape_2d(ctx, ik, n_embd_gqa, kv_size*n_stream);
+    }
+
+    // store the current indexer K values into the cache
+    return ggml_set_rows(ctx, ik, ik_cur, k_idxs);
+}
+
 ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
@@ -1537,6 +1604,18 @@ size_t llama_kv_cache::size_v_bytes() const {
     return size_v_bytes;
 }
 
+size_t llama_kv_cache::size_ik_bytes() const {
+    size_t size_ik_bytes = 0;
+
+    for (const auto & layer : layers) {
+        if (layer.ik != nullptr) {
+            size_ik_bytes += ggml_nbytes(layer.ik);
+        }
+    }
+
+    return size_ik_bytes;
+}
+
 ggml_tensor * llama_kv_cache::build_rope_shift(
         const llama_cparams & cparams,
         ggml_context * ctx,
@@ -2242,6 +2321,10 @@ ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) cons
     return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
 }
 
+ggml_tensor * llama_kv_cache_context::get_ik(ggml_context * ctx, int32_t il) const {
+    return kv->get_ik(ctx, il, n_kv, sinfos[i_cur]);
+}
+
 ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
     return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
 }
@@ -2250,6 +2333,10 @@ ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
 }
 
+ggml_tensor * llama_kv_cache_context::cpy_ik(ggml_context * ctx, ggml_tensor * ik_cur, ggml_tensor * k_idxs, int32_t il) const {
+    return kv->cpy_ik(ctx, ik_cur, k_idxs, il, sinfos[i_cur]);
+}
+
 ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     return kv->build_input_k_idxs(ctx, ubatch);
 }
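Note on the "TODO: add ggml helper function for this?" in cpy_ik above: the
dims 0+1 merge is the same pattern cpy_k already uses, so it could be factored
out along these lines (a sketch only; ggml_view_merged_01 is a hypothetical
name, not an existing ggml function):

    // view a 3D tensor [ne0, ne1, ne2] as 2D [ne0*ne1, ne2];
    // valid only when dim 1 is laid out contiguously after dim 0
    static ggml_tensor * ggml_view_merged_01(ggml_context * ctx, ggml_tensor * t) {
        GGML_ASSERT(ggml_row_size(t->type, t->ne[0]) == t->nb[1]);
        return ggml_view_2d(ctx, t, t->ne[0]*t->ne[1], t->ne[2], t->nb[2], 0);
    }

With such a helper, the GGML_ASSERT + ggml_view_2d pair in cpy_ik would reduce
to a single ik_cur = ggml_view_merged_01(ctx, ik_cur); call.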
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index 33c78c5f21..6e47b40563 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -161,10 +161,12 @@ public:
     // get views of the current state of the cache
     ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
     ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
+    ggml_tensor * get_ik(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
     ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
     ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
+    ggml_tensor * cpy_ik(ggml_context * ctx, ggml_tensor * ik_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
 
     //
     // preparation API
@@ -210,9 +212,11 @@ private:
 
         ggml_tensor * k;
         ggml_tensor * v;
+        ggml_tensor * ik;
 
         std::vector<ggml_tensor *> k_stream;
        std::vector<ggml_tensor *> v_stream;
+        std::vector<ggml_tensor *> ik_stream;
     };
 
     bool v_trans = true; // the value tensor is transposed
@@ -256,6 +260,7 @@ private:
 
     size_t size_k_bytes() const;
     size_t size_v_bytes() const;
+    size_t size_ik_bytes() const;
 
     ggml_tensor * build_rope_shift(
             const llama_cparams & cparams,
@@ -331,6 +336,7 @@ public:
     // get views of the current state of the cache
     ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
     ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
+    ggml_tensor * get_ik(ggml_context * ctx, int32_t il) const;
 
     // store k_cur and v_cur in the cache based on the provided head location
     // note: the heads in k_cur and v_cur should be layed out contiguously in memory
@@ -340,6 +346,7 @@ public:
     // - v_idxs [n_tokens] or [n_tokens*n_embd_v_gqa] depending if V cache is transposed
     ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
     ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
+    ggml_tensor * cpy_ik(ggml_context * ctx, ggml_tensor * ik_cur, ggml_tensor * k_idxs, int32_t il) const;
 
     // create destination indices for each head of the current batch for where it would be written in the KV cache
     // the indices address the global KV cache (not per stream) - this is not relevant for the user of this API, but
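Both cpy_k and cpy_ik delegate the actual write to ggml_set_rows, which
scatters the rows of its second operand into the first at the row indices given
by the third. A self-contained CPU-only sketch of that primitive (assuming a
recent ggml checkout, where ggml_graph_compute_with_ctx is declared in
ggml-cpu.h):

    #include "ggml.h"
    #include "ggml-cpu.h"
    #include <stdio.h>

    int main(void) {
        ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        ggml_context * ctx = ggml_init(params);

        // an 8-row "cache", 2 new rows, and their destination row indices
        ggml_tensor * cache = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 8);
        ggml_tensor * cur   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 2);
        ggml_tensor * idxs  = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 2);

        for (int i = 0; i < 4*8; ++i) ((float *) cache->data)[i] = 0.0f;
        for (int i = 0; i < 4*2; ++i) ((float *) cur->data)[i]   = 1.0f + i;
        ((int64_t *) idxs->data)[0] = 3; // row 0 of cur -> row 3 of cache
        ((int64_t *) idxs->data)[1] = 1; // row 1 of cur -> row 1 of cache

        ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, ggml_set_rows(ctx, cache, cur, idxs));
        ggml_graph_compute_with_ctx(ctx, gf, 1);

        // rows 1 and 3 of cache now hold the rows of cur; the rest stay zero
        printf("cache row 1 starts with %.1f\n", ((float *) cache->data)[1*4]); // 5.0

        ggml_free(ctx);
        return 0;
    }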