kv-cache : support layer reuse (#15504)

* kv-cache : support layer reuse

ggml-ci

* cont : update comments [no ci]
Georgi Gerganov 2025-08-24 13:07:07 +03:00 committed by GitHub
parent c9a24fb932
commit b730706a49
12 changed files with 203 additions and 136 deletions


@@ -153,3 +153,28 @@ bool llama_hparams::is_swa(uint32_t il) const {
GGML_ABORT("fatal error");
}
bool llama_hparams::has_kv(uint32_t il) const {
if (n_layer_kv_from_start >= 0) {
if (il < (uint32_t) n_layer_kv_from_start) {
return true;
}
return false;
}
// by default, all layers have kv
return true;
}
uint32_t llama_hparams::n_layer_kv() const {
uint32_t res = 0;
for (uint32_t il = 0; il < n_layer; ++il) {
if (has_kv(il)) {
res++;
}
}
return res;
}
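
For reference, a minimal standalone sketch of the semantics introduced above; the 8-layer model and the cutoff of 5 are made-up values for illustration, not taken from any real architecture:

    #include <cstdint>
    #include <cstdio>

    // mirrors has_kv()/n_layer_kv(): with a non-negative n_layer_kv_from_start,
    // only the first n_layer_kv_from_start layers get their own KV cache
    int main() {
        const int32_t  n_layer_kv_from_start = 5;
        const uint32_t n_layer               = 8;

        uint32_t n_layer_kv = 0;
        for (uint32_t il = 0; il < n_layer; ++il) {
            const bool has_kv = n_layer_kv_from_start < 0 || il < (uint32_t) n_layer_kv_from_start;
            printf("layer %u: has_kv = %d\n", il, has_kv);
            if (has_kv) {
                n_layer_kv++;
            }
        }

        printf("n_layer_kv = %u\n", n_layer_kv); // prints "n_layer_kv = 5"

        return 0;
    }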


@@ -41,6 +41,7 @@ struct llama_hparams {
uint32_t n_embd;
uint32_t n_embd_features = 0;
uint32_t n_layer;
int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
uint32_t n_rot;
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
@@ -221,6 +222,11 @@ struct llama_hparams {
uint32_t n_pos_per_embd() const;
bool is_swa(uint32_t il) const;
bool has_kv(uint32_t il) const;
// number of layers for which has_kv() returns true
uint32_t n_layer_kv() const;
};
static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");


@@ -22,9 +22,26 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad) : hparams(model.hparams), unified(unified) {
llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
llama_kv_cache::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
uint32_t n_pad,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) {
// chain filters
const layer_filter_cb filter_base = [&](int32_t il) {
if (filter && !filter(il)) {
return false;
}
return !model.hparams.is_swa(il);
};
const layer_filter_cb filter_swa = [&](int32_t il) {
if (filter && !filter(il)) {
return false;
}
return model.hparams.is_swa(il);
};
const uint32_t size_base = kv_size;
@@ -41,16 +58,16 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
kv_base = std::make_unique<llama_kv_cache>(
model, std::move(filter_base), type_k, type_v,
model, type_k, type_v,
v_trans, offload, unified, size_base, n_seq_max, n_pad,
0, LLAMA_SWA_TYPE_NONE);
0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
kv_swa = std::make_unique<llama_kv_cache>(
model, std::move(filter_swa), type_k, type_v,
model, type_k, type_v,
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
hparams.n_swa, hparams.swa_type);
hparams.n_swa, hparams.swa_type, filter_swa, reuse);
}
void llama_kv_cache_iswa::clear(bool data) {

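The constructor above now chains an optional caller-supplied filter with the SWA/non-SWA split. A self-contained sketch of that composition pattern, using a made-up SWA layout and a made-up outer filter (neither reflects a real model):

    #include <cstdint>
    #include <cstdio>
    #include <functional>

    using layer_filter_cb = std::function<bool(int32_t il)>;

    int main() {
        // made-up SWA layout: every 3rd layer is full attention, the rest are SWA
        auto is_swa = [](int32_t il) { return (il % 3) != 2; };

        // made-up caller filter: pretend the last layer should not be cached at all
        layer_filter_cb filter = [](int32_t il) { return il < 5; };

        // chained filters, mirroring filter_base/filter_swa above:
        // a layer ends up in a cache only if the outer filter also accepts it
        layer_filter_cb filter_base = [&](int32_t il) { return (!filter || filter(il)) && !is_swa(il); };
        layer_filter_cb filter_swa  = [&](int32_t il) { return (!filter || filter(il)) &&  is_swa(il); };

        for (int32_t il = 0; il < 6; ++il) {
            printf("layer %d -> non-SWA cache: %d, SWA cache: %d\n",
                   il, (int) filter_base(il), (int) filter_swa(il));
        }

        return 0;
    }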

@@ -20,11 +20,13 @@ public:
bool v_trans,
bool offload,
bool swa_full,
bool ,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_ubatch,
uint32_t n_pad);
uint32_t n_pad,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
~llama_kv_cache_iswa() = default;


@@ -17,32 +17,25 @@
//
llama_kv_cache::llama_kv_cache(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type) :
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse) :
model(model), hparams(model.hparams), v_trans(v_trans),
n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
GGML_ASSERT(kv_size % n_pad == 0);
// TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
auto n_layer_cache = hparams.n_layer;
if (model.arch == LLM_ARCH_GEMMA3N) {
n_layer_cache = 20;
}
if (model.arch == LLM_ARCH_GLM4_MOE) {
// GLM-4.5: Only process up to last layer, skip final NextN layer
n_layer_cache = hparams.n_layer - hparams.nextn_predict_layers;
}
const uint32_t n_layer_kv = hparams.n_layer_kv();
// create a context for each buffer type
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
@@ -50,7 +43,7 @@ llama_kv_cache::llama_kv_cache(
auto it = ctx_map.find(buft);
if (it == ctx_map.end()) {
ggml_init_params params = {
/*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()),
/*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
@@ -97,9 +90,14 @@ llama_kv_cache::llama_kv_cache(
__func__, hparams.n_embd_v_gqa_max());
}
for (uint32_t il = 0; il < n_layer_cache; il++) {
for (uint32_t il = 0; il < hparams.n_layer; il++) {
if (!hparams.has_kv(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
continue;
}
if (filter && !filter(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
LLAMA_LOG_DEBUG("%s: layer %3d: filtered\n", __func__, il);
continue;
}
@@ -147,23 +145,27 @@ llama_kv_cache::llama_kv_cache(
layers.push_back({ il, k, v, k_stream, v_stream, });
}
// TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
if (model.arch == LLM_ARCH_GEMMA3N) {
LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
if (reuse) {
LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__);
for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
if (filter && !filter(il)) {
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
for (uint32_t il = 0; il < hparams.n_layer; il++) {
const int32_t il_reuse = reuse(il);
if (il_reuse < 0) {
LLAMA_LOG_DEBUG("%s: - layer %3d: no reuse\n", __func__, il);
continue;
}
const bool is_swa = hparams.is_swa(il);
const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
if (filter && !filter(il)) {
LLAMA_LOG_DEBUG("%s: - layer %3d: filtered\n", __func__, il);
continue;
}
GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
map_layer_ids[il] = map_layer_ids[il_reuse];
LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
LLAMA_LOG_DEBUG("%s: - layer %3d: reuse layer %d, is_swa = %d\n", __func__, il, il_reuse, hparams.is_swa(il));
}
}
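
A condensed, standalone sketch of the layer-id mapping that the allocation and reuse passes above end up producing; the 6-layer setup and the reuse rule are invented purely for illustration:

    #include <cstdint>
    #include <cstdio>
    #include <functional>
    #include <map>

    using layer_reuse_cb = std::function<int32_t(int32_t il)>;

    int main() {
        const int32_t n_layer = 6;

        // invented rule: layers 4 and 5 own no KV tensors and reuse layer 3
        layer_reuse_cb reuse = [](int32_t il) { return il >= 4 ? 3 : -1; };

        // first pass (condensed): layers that own KV tensors get an index into `layers`
        std::map<int32_t, int32_t> map_layer_ids;
        for (int32_t il = 0; il < n_layer; ++il) {
            if (reuse(il) < 0) {
                const int32_t idx = (int32_t) map_layer_ids.size();
                map_layer_ids[il] = idx;
            }
        }

        // second pass: redirected layers point at the entry of the layer they reuse
        for (int32_t il = 0; il < n_layer; ++il) {
            const int32_t il_reuse = reuse(il);
            if (il_reuse >= 0) {
                map_layer_ids[il] = map_layer_ids.at(il_reuse);
            }
        }

        for (const auto & [il, idx] : map_layer_ids) {
            printf("layer %d -> KV tensors of slot %d\n", il, idx);
        }

        return 0;
    }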


@@ -21,9 +21,6 @@ class llama_kv_cache : public llama_memory_i {
public:
static uint32_t get_padding(const llama_cparams & cparams);
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
struct stream_copy_info {
bool empty() const {
assert(ssrc.size() == sdst.size());
@@ -82,18 +79,19 @@ public:
using slot_info_vec_t = std::vector<slot_info>;
llama_kv_cache(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type);
const llama_model & model,
ggml_type type_k,
ggml_type type_v,
bool v_trans,
bool offload,
bool unified,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
const layer_filter_cb & filter,
const layer_reuse_cb & reuse);
~llama_kv_cache() = default;


@@ -9,32 +9,29 @@
//
llama_memory_hybrid::llama_memory_hybrid(
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
layer_filter_cb && filter_attn,
layer_filter_cb && filter_recr) :
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
const layer_filter_cb & filter_attn,
const layer_filter_cb & filter_recr) :
hparams(model.hparams),
mem_attn(new llama_kv_cache(
model,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
type_k,
type_v,
v_trans,
@@ -44,18 +41,22 @@ llama_memory_hybrid::llama_memory_hybrid(
n_seq_max,
n_pad,
n_swa,
swa_type
swa_type,
filter_attn == nullptr ?
[&](int32_t il) { return !hparams.is_recurrent(il); }
: filter_attn,
nullptr
)),
mem_recr(new llama_memory_recurrent(
model,
filter_recr == nullptr ?
[&](int32_t il) { return hparams.is_recurrent(il); }
: filter_recr,
type_r,
type_s,
offload,
rs_size,
n_seq_max
n_seq_max,
filter_recr == nullptr ?
[&](int32_t il) { return hparams.is_recurrent(il); }
: filter_recr
)) {}
llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {


@@ -18,31 +18,27 @@
class llama_memory_hybrid : public llama_memory_i {
public:
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
llama_memory_hybrid(
const llama_model & model,
/* attn */
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
layer_filter_cb && filter_attn = nullptr,
layer_filter_cb && filter_recr = nullptr);
ggml_type type_k,
ggml_type type_v,
bool v_trans,
uint32_t kv_size,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type,
/* recurrent */
ggml_type type_r,
ggml_type type_s,
uint32_t rs_size,
/* common */
uint32_t n_seq_max,
bool offload,
bool unified,
/* layer filters */
const layer_filter_cb & filter_attn = nullptr,
const layer_filter_cb & filter_recr = nullptr);
~llama_memory_hybrid() = default;


@@ -16,13 +16,13 @@
//
llama_memory_recurrent::llama_memory_recurrent(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_r,
ggml_type type_s,
bool offload,
uint32_t mem_size,
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
const llama_model & model,
ggml_type type_r,
ggml_type type_s,
bool offload,
uint32_t mem_size,
uint32_t n_seq_max,
const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) {
const int32_t n_layer = hparams.n_layer;
head = 0;


@@ -15,18 +15,14 @@
// see the implementation of llama_kv_cache_context_i for an example how to do it
class llama_memory_recurrent : public llama_memory_i {
public:
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
llama_memory_recurrent(
const llama_model & model,
layer_filter_cb && filter,
ggml_type type_r,
ggml_type type_s,
bool offload,
uint32_t mem_size,
uint32_t n_seq_max);
const llama_model & model,
ggml_type type_r,
ggml_type type_s,
bool offload,
uint32_t mem_size,
uint32_t n_seq_max,
const layer_filter_cb & filter);
~llama_memory_recurrent() = default;


@@ -3,6 +3,7 @@
#include "llama.h"
#include <memory>
#include <functional>
struct llama_ubatch;
@@ -64,6 +65,13 @@ using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;
// general concept of LLM memory
// the KV cache is a type of LLM memory, but there can be other types
struct llama_memory_i {
// this callback is used to filter out layers that should not be included in the cache
using layer_filter_cb = std::function<bool(int32_t il)>;
// this callback is used to specify which layers should reuse memory from other layers
// return negative value to indicate that the layer il should not reuse memory
using layer_reuse_cb = std::function<int32_t(int32_t il)>;
virtual ~llama_memory_i() = default;
// split the input batch into a set of ubatches and verify that they can fit into the cache
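
For illustration, here is what implementations of the two callbacks can look like; the type aliases mirror the declarations above, while the 8-layer policy itself is hypothetical:

    #include <cstdint>
    #include <functional>

    // same shapes as the callbacks declared in llama_memory_i above
    using layer_filter_cb = std::function<bool(int32_t il)>;
    using layer_reuse_cb  = std::function<int32_t(int32_t il)>;

    // hypothetical policy for an 8-layer model:
    // - every layer participates in this cache
    // - layers 6 and 7 allocate no KV tensors and read/write the cells of layer 5
    static const layer_filter_cb filter_all = [](int32_t /*il*/) {
        return true; // do not exclude any layer
    };

    static const layer_reuse_cb reuse_tail = [](int32_t il) {
        return il >= 6 ? 5 : -1; // negative value = no reuse, the layer keeps its own cache
    };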


@@ -1115,6 +1115,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.set_swa_pattern(5);
hparams.n_layer_kv_from_start = 20;
hparams.rope_freq_base_train_swa = 10000.0f;
hparams.rope_freq_scale_train_swa = 1.0f;
hparams.f_attention_scale = 1.0f;
@@ -1474,12 +1475,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
// Expert gating function (GLM-4.5 uses sigmoid)
ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) {
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID;
}
// NextN/MTP parameters
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
// TODO: when MTP is implemented, this should probably be updated if needed
hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers;
switch (hparams.n_layer) {
case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer)
case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer)
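
As a quick arithmetic check of the assignment above, using the GLM-4.5-Air figures from the case just listed (46 transformer layers plus 1 NextN layer):

    #include <cstdint>

    // GLM-4.5-Air: 47 layers total, of which the last one is the NextN/MTP layer
    const uint32_t n_layer               = 47;
    const uint32_t nextn_predict_layers  = 1;
    const int32_t  n_layer_kv_from_start = (int32_t) (n_layer - nextn_predict_layers); // = 46
    // => layers 0..45 allocate a KV cache; layer 46 (the NextN layer) does not
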
@@ -10524,7 +10528,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
const int64_t n_embd_altup;
const int64_t n_altup;
const int i_altup_act;
const int n_layer_kv = 20; // number of layers having KV [KV_REUSE]
const int n_layer_sparsity = 10; // number of layers using activation sparsity
const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
@@ -10574,8 +10577,6 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
for (int il = 0; il < n_layer; ++il) {
// this block is made to be closely resemble Gemma3p5DecoderLayer on python code
const bool has_kv = (il < n_layer_kv);
const float freq_base_l = model.get_rope_freq_base (cparams, il);
const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
@@ -10595,7 +10596,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]
// self-attention
if (has_kv) {
if (hparams.has_kv(il)) {
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
@@ -10635,7 +10636,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
} else {
// no KV layers
// reuse KV cache of earlier layers
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
@@ -18256,12 +18257,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
if (llm_arch_is_recurrent(arch)) {
res = new llama_memory_recurrent(
*this,
nullptr,
GGML_TYPE_F32,
GGML_TYPE_F32,
cparams.offload_kqv,
std::max((uint32_t) 1, cparams.n_seq_max),
cparams.n_seq_max);
cparams.n_seq_max,
nullptr);
} else if (llm_arch_is_hybrid(arch)) {
const auto padding = llama_kv_cache::get_padding(cparams);
@@ -18302,6 +18303,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
llama_memory_i::layer_reuse_cb reuse = nullptr;
if (arch == LLM_ARCH_GEMMA3N) {
reuse = [&](int32_t il) {
if (il >= (int32_t) hparams.n_layer_kv_from_start) {
return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1);
}
return -1;
};
}
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
GGML_ASSERT(hparams.is_swa_any());
@@ -18316,13 +18329,14 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
n_ctx_per_stream,
cparams.n_seq_max,
cparams.n_ubatch,
padding);
padding,
nullptr,
reuse);
} else {
GGML_ASSERT(!hparams.is_swa_any());
res = new llama_kv_cache(
*this,
nullptr,
params.type_k,
params.type_v,
!cparams.flash_attn,
@@ -18332,7 +18346,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
cparams.n_seq_max,
padding,
hparams.n_swa,
hparams.swa_type);
hparams.swa_type,
nullptr,
nullptr);
}
}
}
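
To make the Gemma3n policy above concrete: every layer from n_layer_kv_from_start = 20 onward is redirected to one of the last two layers that still own a cache. The sketch below assumes a 30-layer model and assumes that set_swa_pattern(5) marks every 5th layer (il % 5 == 4) as full attention; both are illustrative readings rather than authoritative values:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // assumed Gemma3n-like setup (layer count and SWA layout are assumptions)
        const int32_t n_layer               = 30;
        const int32_t n_layer_kv_from_start = 20;

        auto is_swa = [](int32_t il) { return (il % 5) != 4; };

        // same rule as the reuse callback registered for LLM_ARCH_GEMMA3N above
        auto reuse = [&](int32_t il) {
            if (il >= n_layer_kv_from_start) {
                return n_layer_kv_from_start - (is_swa(il) ? 2 : 1);
            }
            return -1;
        };

        // under the assumed layout, SWA layers past the cutoff share layer 18 (the last
        // SWA layer owning a cache) and full-attention layers share layer 19
        for (int32_t il = n_layer_kv_from_start; il < n_layer; ++il) {
            printf("layer %2d (is_swa = %d) -> reuses layer %d\n", il, (int) is_swa(il), reuse(il));
        }

        return 0;
    }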