From 744c0c7310aad90e99a29c5739e4ee317fb6a748 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 1 Apr 2026 16:58:01 +0300 Subject: [PATCH] llama : rotate activations for better quantization (#21038) * llama : rotate activations for better quantization * cont : rotate V more + refactor * cont : rotate caches separately + support non-power-of-2 head sizes * cont : simplify * cont : add reference for V rotation * cont : refactor * cont : support context shift * cont : consolidate * cont : dedup + allow different types for the rotation matrix * cont : add env variable to disable rotation * cont : simplify attn rot kv cache logic + rename env * cont : pre-compute the Hadamard matrices --- src/llama-graph.cpp | 119 ++++++++++++++++++----- src/llama-graph.h | 8 ++ src/llama-kv-cache.cpp | 210 ++++++++++++++++++++++++++++++++++++++++- src/llama-kv-cache.h | 26 +++++ 4 files changed, 337 insertions(+), 26 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index c2833b75ce..0e7d96ca10 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -19,7 +19,7 @@ // dedup helpers -static ggml_tensor * build_kq_mask( +static ggml_tensor * build_attn_inp_kq_mask( ggml_context * ctx, const llama_kv_cache_context * mctx, const llama_ubatch & ubatch, @@ -28,7 +28,11 @@ static ggml_tensor * build_kq_mask( const auto n_tokens = ubatch.n_tokens; const auto n_stream = cparams.kv_unified ? 
1 : ubatch.n_seqs_unq; - return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_kq_mask"); + + return res; } static bool can_reuse_kq_mask( @@ -52,6 +56,21 @@ static bool can_reuse_kq_mask( // impl +static ggml_tensor * ggml_mul_mat_aux( + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * rot) { + const auto n = rot->ne[0]; + + ggml_tensor * res; + + res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + res = ggml_mul_mat (ctx, rot, res); + res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + + return res; +} + void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { const int64_t n_tokens = ubatch->n_tokens; @@ -429,6 +448,14 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) { mctx->set_input_v_idxs(self_v_idxs, ubatch); mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + + if (self_k_rot) { + mctx->set_input_k_rot(self_k_rot); + } + + if (self_v_rot) { + mctx->set_input_v_rot(self_v_rot); + } } bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { @@ -476,6 +503,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + + if (self_k_rot) { + mctx->get_base()->set_input_k_rot(self_k_rot); + } + + if (self_v_rot) { + mctx->get_base()->set_input_v_rot(self_v_rot); + } } bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { @@ -532,6 +567,14 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + if (inp_attn->self_k_rot) { + 
mctx->get_attn()->set_input_k_rot(inp_attn->self_k_rot); + } + + if (inp_attn->self_v_rot) { + mctx->get_attn()->set_input_v_rot(inp_attn->self_v_rot); + } + const int64_t n_rs = mctx->get_recr()->get_n_rs(); if (inp_rs->s_copy) { @@ -630,6 +673,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn); } + if (inp_attn->self_k_rot) { + attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot); + } + + if (inp_attn->self_v_rot) { + attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot); + } + const int64_t n_rs = mctx->get_recr()->get_n_rs(); if (inp_rs->s_copy) { @@ -2002,13 +2053,13 @@ static std::unique_ptr build_attn_inp_kv_impl( inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); - - ggml_set_input(inp->self_kq_mask); - + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); inp->self_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } + inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0); + inp->self_v_rot = mctx_cur->build_input_v_rot(ctx0); + return inp; } @@ -2034,6 +2085,15 @@ ggml_tensor * llm_graph_context::build_attn( int il) const { GGML_ASSERT(v_mla == nullptr); + if (inp->self_k_rot) { + q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot); + k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot); + } + + if (inp->self_v_rot) { + v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot); + } + // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced // expand k later to enable rope fusion which directly writes into k-v cache @@ -2061,6 +2121,10 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); + if (inp->self_v_rot) { + cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot); + } + if (wo) { cur = build_lora_mm(wo, cur); if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) { @@ -2090,9 +2154,7 @@ static std::unique_ptr build_attn_inp_k_impl( inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); - inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); - ggml_set_input(inp->self_kq_mask); - + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); inp->self_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } @@ -2171,6 +2233,18 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v_mla, float kq_scale, int il) const { + if (inp->self_k_rot) { + q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot); + if (k_cur) { + k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot); + } + } + if (inp->self_v_rot) { + if (v_cur) { + v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot); + } + } + // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced ggml_build_forward_expand(gf, q_cur); @@ -2211,6 +2285,10 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); + if (inp->self_v_rot) { + cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot); + } + if (wo) { cur = build_lora_mm(wo, cur); } @@ -2293,12 +2371,8 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); - ggml_set_input(inp->self_kq_mask); - ggml_set_name(inp->self_kq_mask, "self_kq_mask"); - + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); inp->self_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; - ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv"); } { @@ -2307,14 +2381,13 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); - ggml_set_input(inp->self_kq_mask_swa); - ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa"); - + inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; - ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv"); } + inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0); + inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0); + return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp)); } @@ -2473,9 +2546,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch); inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); - inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); - ggml_set_input(inp_attn->self_kq_mask); - + inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); inp_attn->self_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; } @@ -2483,9 +2554,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch); inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); - ggml_set_input(inp_attn->self_kq_mask_swa); - + inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; } diff --git a/src/llama-graph.h b/src/llama-graph.h index 4855685ef7..bb0ad75198 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -308,6 +308,10 @@ public: ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + // note: assumes v_rot^2 == I + ggml_tensor * self_k_rot = nullptr; + ggml_tensor * self_v_rot = nullptr; + // note: these have to be copies because in order to be able to reuse a graph, its inputs // need to carry these parameters with them. 
otherwise, they can point to freed + // llm_graph_params from a previous batch, causing stack-use-after-return @@ -384,6 +388,10 @@ public: ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + // note: using same rotation matrices for both base and swa cache + ggml_tensor * self_k_rot = nullptr; + ggml_tensor * self_v_rot = nullptr; + const llama_hparams hparams; const llama_cparams cparams; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 5f57ba9e1d..3e0fd3107f 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -13,6 +13,65 @@ #include #include +static bool ggml_is_power_of_2(int n) { + return (n & (n - 1)) == 0; +} + +// orthonormal Walsh-Hadamard rotation matrix +// note: res^2 == I +static void ggml_gen_hadamard(ggml_tensor * tensor) { + assert(tensor->type == GGML_TYPE_F32); + + const int n = tensor->ne[0]; + + assert(ggml_is_power_of_2(n)); + assert(tensor->ne[1] == n); + assert(tensor->ne[2] == 1); + assert(tensor->ne[3] == 1); + + std::vector<float> data_f32; + + float * data = (float *) tensor->data; + + if (tensor->type != GGML_TYPE_F32) { + data_f32.resize(n*n); + data = data_f32.data(); + } + + data[0*n + 0] = 1.0 / sqrtf(n); + + for (int s = 1; s < n; s *= 2) { + for (int i = 0; i < s; i++) { + for (int j = 0; j < s; j++) { + const float val = data[i*n + j]; + + data[(i + s)*n + (j )] = val; + data[(i )*n + (j + s)] = val; + data[(i + s)*n + (j + s)] = -val; + } + } + } + + if (tensor->type != GGML_TYPE_F32) { + ggml_quantize_chunk(tensor->type, data, tensor->data, 0, 1, n*n, nullptr); + } +} + +static ggml_tensor * ggml_mul_mat_aux( + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * rot) { + const auto n = rot->ne[0]; + + ggml_tensor * res; + + res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + res = ggml_mul_mat (ctx, rot, res); + res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], 
cur->ne[2], cur->ne[3]); + + return res; +} + // // llama_kv_cache // @@ -209,6 +268,48 @@ llama_kv_cache::llama_kv_cache( ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } + const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE"); + const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false; + if (attn_rot_disable) { + LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__); + } + + attn_rot_k = + !attn_rot_disable && + ggml_is_quantized(type_k) && + !hparams.is_n_embd_k_gqa_variable() && + hparams.n_embd_head_k() % 64 == 0; + + attn_rot_v = + !attn_rot_disable && + ggml_is_quantized(type_v) && + !hparams.is_n_embd_v_gqa_variable() && + hparams.n_embd_head_v() % 64 == 0; + + LLAMA_LOG_INFO("%s: attn_rot_k = %d\n", __func__, attn_rot_k); + LLAMA_LOG_INFO("%s: attn_rot_v = %d\n", __func__, attn_rot_v); + + // pre-compute the Hadamard matrices and keep them in host memory + // TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers + if (attn_rot_k || attn_rot_v) { + for (int64_t n = 64; n <= std::max(hparams.n_embd_head_k(), hparams.n_embd_head_v()); n *= 2) { + attn_rot_hadamard[n] = std::vector<float>(n*n); + + ggml_init_params params = { + /* .mem_size = */ 1*ggml_tensor_overhead(), + /* .mem_buffer = */ nullptr, + /* .no_alloc = */ true, + }; + + ggml_context_ptr ctx { ggml_init(params) }; + + ggml_tensor * tmp = ggml_new_tensor_2d(ctx.get(), GGML_TYPE_F32, n, n); + tmp->data = attn_rot_hadamard[n].data(); + + ggml_gen_hadamard(tmp); + } + } + const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG"); debug = LLAMA_KV_CACHE_DEBUG ? 
atoi(LLAMA_KV_CACHE_DEBUG) : 0; } @@ -1004,6 +1105,14 @@ bool llama_kv_cache::get_has_shift() const { return result; } +ggml_type llama_kv_cache::type_k() const { + return layers[0].k->type; +} + +ggml_type llama_kv_cache::type_v() const { + return layers[0].v->type; +} + uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const { uint32_t result = 0; @@ -1189,6 +1298,47 @@ ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama return v_idxs; } +ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const { + ggml_tensor * res = nullptr; + + if (attn_rot_k) { + int nrot = 64; + + // TODO: investigate if using the smallest rotation matrix is beneficial also for K (similar as for V) + // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088 + do { + nrot *= 2; + } while (hparams.n_embd_head_k() % nrot == 0); + nrot /= 2; + + res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_k_rot"); + } + + return res; +} + +ggml_tensor * llama_kv_cache::build_input_v_rot(ggml_context * ctx) const { + ggml_tensor * res = nullptr; + + if (attn_rot_v) { + int nrot = 64; + // using smaller rotation matrices for V seems beneficial + // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4146397570 + //do { + // nrot *= 2; + //} while (hparams.n_embd_head_v() % nrot == 0); + //nrot /= 2; + + res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_v_rot"); + } + + return res; +} + void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); @@ -1507,6 +1657,24 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch } } +void llama_kv_cache::set_input_k_rot(ggml_tensor * dst) const { + 
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + const auto n_rot = dst->ne[0]; + GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0])); + + memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst)); +} + +void llama_kv_cache::set_input_v_rot(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + const auto n_rot = dst->ne[0]; + GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0])); + + memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst)); +} + size_t llama_kv_cache::total_size() const { size_t size = 0; @@ -1542,6 +1710,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift( ggml_context * ctx, ggml_tensor * cur, ggml_tensor * shift, + ggml_tensor * rot, ggml_tensor * factors, float freq_base, float freq_scale, @@ -1567,10 +1736,16 @@ ggml_tensor * llama_kv_cache::build_rope_shift( // dequantize to f32 -> RoPE -> quantize back tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); + // rotate back + tmp = ggml_mul_mat_aux(ctx, tmp, rot); + tmp = ggml_rope_ext(ctx, tmp, shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + // rotate fwd + tmp = ggml_mul_mat_aux(ctx, tmp, rot); + tmp = ggml_cpy(ctx, tmp, cur); } else { // we rotate only the first n_rot dimensions @@ -1591,6 +1766,9 @@ public: ggml_tensor * k_shift; // I32 [kv_size*n_stream] + // note: assumes k_rot^2 == I + ggml_tensor * k_rot = nullptr; + const llama_kv_cache * kv_self; }; @@ -1600,6 +1778,10 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { if (k_shift) { kv_self->set_input_k_shift(k_shift); } + + if (k_rot) { + kv_self->set_input_k_rot(k_rot); + } } ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const { @@ -1611,6 +1793,8 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream); 
ggml_set_input(inp->k_shift); + inp->k_rot = build_input_k_rot(ctx); + const auto & cparams = lctx->get_cparams(); for (const auto & layer : layers) { @@ -1635,7 +1819,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co ggml_row_size(layer.k->type, n_embd_k_gqa), ggml_row_size(layer.k->type, n_embd_nope)); - ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il); + ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, inp->k_rot, rope_factors, freq_base_l, freq_scale_l, il); ggml_build_forward_expand(gf, cur); } @@ -2239,6 +2423,14 @@ uint32_t llama_kv_cache_context::get_n_kv() const { return n_kv; } +ggml_type llama_kv_cache_context::type_k() const { + return kv->type_k(); +} + +ggml_type llama_kv_cache_context::type_v() const { + return kv->type_v(); +} + ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const { return kv->get_k(ctx, il, n_kv, sinfos[i_cur]); } @@ -2263,6 +2455,14 @@ ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, con return kv->build_input_v_idxs(ctx, ubatch); } +ggml_tensor * llama_kv_cache_context::build_input_k_rot(ggml_context * ctx) const { + return kv->build_input_k_rot(ctx); +} + +ggml_tensor * llama_kv_cache_context::build_input_v_rot(ggml_context * ctx) const { + return kv->build_input_v_rot(ctx); +} + void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const { kv->set_input_k_shift(dst); } @@ -2282,3 +2482,11 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { kv->set_input_pos_bucket(dst, ubatch); } + +void llama_kv_cache_context::set_input_k_rot(ggml_tensor * dst) const { + kv->set_input_k_rot(dst); +} + +void llama_kv_cache_context::set_input_v_rot(ggml_tensor * dst) const { + kv->set_input_v_rot(dst); +} diff --git 
a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 90a0610c49..d4569a06f7 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -152,6 +152,9 @@ public: bool get_has_shift() const; + ggml_type type_k() const; + ggml_type type_v() const; + // // graph_build API // @@ -191,6 +194,9 @@ public: ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_k_rot(ggml_context * ctx) const; + ggml_tensor * build_input_v_rot(ggml_context * ctx) const; + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; @@ -199,6 +205,9 @@ public: void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_input_k_rot(ggml_tensor * dst) const; + void set_input_v_rot(ggml_tensor * dst) const; + private: const llama_model & model; const llama_hparams & hparams; @@ -226,6 +235,13 @@ private: // SWA const uint32_t n_swa = 0; + // env: LLAMA_ATTN_ROT_DISABLE + bool attn_rot_k = false; + bool attn_rot_v = false; + + // pre-computed Hadamard matrices + std::unordered_map<int64_t, std::vector<float>> attn_rot_hadamard; + // env: LLAMA_KV_CACHE_DEBUG int debug = 0; @@ -262,6 +278,7 @@ private: ggml_context * ctx, ggml_tensor * cur, ggml_tensor * shift, + ggml_tensor * rot, ggml_tensor * factors, float freq_base, float freq_scale, @@ -328,6 +345,9 @@ public: uint32_t get_n_kv() const; + ggml_type type_k() const; + ggml_type type_v() const; + // get views of the current state of the cache ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; @@ -347,6 +367,9 @@ public: ggml_tensor * build_input_k_idxs(ggml_context * ctx, const 
llama_ubatch & ubatch) const; ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_k_rot(ggml_context * ctx) const; + ggml_tensor * build_input_v_rot(ggml_context * ctx) const; + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; @@ -354,6 +377,9 @@ public: void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_input_k_rot(ggml_tensor * dst) const; + void set_input_v_rot(ggml_tensor * dst) const; + private: llama_memory_status status;