From 4eb19514dd2984662f13aacbb052c559c8fde3b1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 7 Apr 2026 20:31:28 +0300 Subject: [PATCH] kv-cache : support attention rotation for heterogeneous iSWA (#21513) * kv-cache : support attention rotation for heterogeneous iSWA * cont : remove assert --- src/llama-graph.cpp | 40 +++++++++++++++++++++++++++++++--------- src/llama-graph.h | 6 ++++-- src/llama-kv-cache.cpp | 24 ++++++++++++++++++------ src/llama-kv-cache.h | 5 +++++ 4 files changed, 58 insertions(+), 17 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0e7d96ca10..d6f5c5eab5 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -511,6 +511,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { if (self_v_rot) { mctx->get_base()->set_input_v_rot(self_v_rot); } + + if (self_k_rot_swa) { + mctx->get_swa()->set_input_k_rot(self_k_rot_swa); + } + + if (self_v_rot_swa) { + mctx->get_swa()->set_input_v_rot(self_v_rot_swa); + } } bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { @@ -681,6 +689,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot); } + if (inp_attn->self_k_rot_swa) { + attn_ctx->get_swa()->set_input_k_rot(inp_attn->self_k_rot_swa); + } + + if (inp_attn->self_v_rot_swa) { + attn_ctx->get_swa()->set_input_v_rot(inp_attn->self_v_rot_swa); + } + const int64_t n_rs = mctx->get_recr()->get_n_rs(); if (inp_rs->s_copy) { @@ -2233,15 +2249,20 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v_mla, float kq_scale, int il) const { - if (inp->self_k_rot) { - q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot); + const bool is_swa = hparams.is_swa(il); + + auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot; + auto * v_rot = is_swa ? 
inp->self_v_rot_swa : inp->self_v_rot; + + if (k_rot) { + q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot); if (k_cur) { - k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot); + k_cur = ggml_mul_mat_aux(ctx0, k_cur, k_rot); } } - if (inp->self_v_rot) { + if (v_rot) { if (v_cur) { - v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot); + v_cur = ggml_mul_mat_aux(ctx0, v_cur, v_rot); } } @@ -2259,8 +2280,6 @@ ggml_tensor * llm_graph_context::build_attn( const auto * mctx_iswa = inp->mctx; - const bool is_swa = hparams.is_swa(il); - const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base(); // optionally store to KV cache @@ -2285,8 +2304,8 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); - if (inp->self_v_rot) { - cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot); + if (v_rot) { + cur = ggml_mul_mat_aux(ctx0, cur, v_rot); } if (wo) { @@ -2388,6 +2407,9 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0); inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0); + inp->self_k_rot_swa = mctx_cur->get_swa()->build_input_k_rot(ctx0); + inp->self_v_rot_swa = mctx_cur->get_swa()->build_input_v_rot(ctx0); + return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp)); } diff --git a/src/llama-graph.h b/src/llama-graph.h index bb0ad75198..29e78451fb 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -308,7 +308,7 @@ public: ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] - // note: assumes v_rot^ == I + // note: assumes v_rot^2 == I ggml_tensor * self_k_rot = nullptr; ggml_tensor * self_v_rot = nullptr; @@ -388,10 +388,12 @@ public: ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, 
n_stream] ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] - // note: using same rotation matrices for both base and swa cache ggml_tensor * self_k_rot = nullptr; ggml_tensor * self_v_rot = nullptr; + ggml_tensor * self_k_rot_swa = nullptr; + ggml_tensor * self_v_rot_swa = nullptr; + const llama_hparams hparams; const llama_cparams cparams; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 3e0fd3107f..09102f549c 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -169,6 +169,18 @@ llama_kv_cache::llama_kv_cache( continue; } + if (n_embd_head_k_all == 0) { + n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il); + } else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) { + n_embd_head_k_all = -1; + } + + if (n_embd_head_v_all == 0) { + n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il); + } else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) { + n_embd_head_v_all = -1; + } + // [TAG_V_CACHE_VARIABLE] const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const uint32_t n_embd_v_gqa = !v_trans ? 
hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
@@ -276,23 +288,23 @@ llama_kv_cache::llama_kv_cache(
 
     attn_rot_k =
         !attn_rot_disable &&
+        n_embd_head_k_all > 0 &&
         ggml_is_quantized(type_k) &&
-        !hparams.is_n_embd_k_gqa_variable() &&
         hparams.n_embd_head_k() % 64 == 0;
 
     attn_rot_v =
         !attn_rot_disable &&
+        n_embd_head_v_all > 0 &&
         ggml_is_quantized(type_v) &&
-        !hparams.is_n_embd_v_gqa_variable() &&
         hparams.n_embd_head_v() % 64 == 0;
 
-    LLAMA_LOG_INFO("%s: attn_rot_k = %d\n", __func__, attn_rot_k);
-    LLAMA_LOG_INFO("%s: attn_rot_v = %d\n", __func__, attn_rot_v);
+    LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all);
+    LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_v_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all);
 
     // pre-compute the haramard matrices
     // TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers
     if (attn_rot_k || attn_rot_v) {
-        for (int64_t n = 64; n <= std::max(hparams.n_embd_head_k(), hparams.n_embd_head_v()); n *= 2) {
+        for (int64_t n = 64; n <= std::max(n_embd_head_k_all, n_embd_head_v_all); n *= 2) {
             attn_rot_hadamard[n] = std::vector(n*n);
 
             ggml_init_params params = {
@@ -1308,7 +1320,7 @@ ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const {
     // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088
     do {
         nrot *= 2;
-    } while (hparams.n_embd_head_k() % nrot == 0);
+    } while (n_embd_head_k_all % nrot == 0);
     nrot /= 2;
 
     res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index d4569a06f7..0b62dc7b23 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -239,6 +239,11 @@ private:
     bool attn_rot_k = false;
     bool attn_rot_v = false;
 
+    // if all layers participating in the cache have constant head size, the value is stored here
+    // otherwise the value is -1
+    int32_t n_embd_head_k_all = 0;
+    int32_t 
n_embd_head_v_all = 0; + // pre-computed hadamard martrices std::unordered_map> attn_rot_hadamard;