llama : rotate activations for better quantization (#21038)

* llama : rotate activations for better quantization

* cont : rotate V more + refactor

* cont : rotate caches separately + support non-power-of-2 head sizes

* cont : simplify

* cont : add reference for V rotation

* cont : refactor

* cont : support context shift

* cont : consolidate

* cont : dedup + allow different types for the rotation matrix

* cont : add env variable to disable rotation

* cont : simplify attn rot kv cache logic + rename env

* cont : pre-compute the Hadamard matrices
This commit is contained in:
Georgi Gerganov 2026-04-01 16:58:01 +03:00 committed by GitHub
parent 0356e33aaf
commit 744c0c7310
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 337 additions and 26 deletions

View File

@ -19,7 +19,7 @@
// dedup helpers
static ggml_tensor * build_kq_mask(
static ggml_tensor * build_attn_inp_kq_mask(
ggml_context * ctx,
const llama_kv_cache_context * mctx,
const llama_ubatch & ubatch,
@ -28,7 +28,11 @@ static ggml_tensor * build_kq_mask(
const auto n_tokens = ubatch.n_tokens;
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(res);
ggml_set_name(res, "attn_inp_kq_mask");
return res;
}
static bool can_reuse_kq_mask(
@ -52,6 +56,21 @@ static bool can_reuse_kq_mask(
// impl
static ggml_tensor * ggml_mul_mat_aux(
ggml_context * ctx,
ggml_tensor * cur,
ggml_tensor * rot) {
const auto n = rot->ne[0];
ggml_tensor * res;
res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
res = ggml_mul_mat (ctx, rot, res);
res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
return res;
}
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
if (ubatch->token) {
const int64_t n_tokens = ubatch->n_tokens;
@ -429,6 +448,14 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
mctx->set_input_v_idxs(self_v_idxs, ubatch);
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
if (self_k_rot) {
mctx->set_input_k_rot(self_k_rot);
}
if (self_v_rot) {
mctx->set_input_v_rot(self_v_rot);
}
}
bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
@ -476,6 +503,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
if (self_k_rot) {
mctx->get_base()->set_input_k_rot(self_k_rot);
}
if (self_v_rot) {
mctx->get_base()->set_input_v_rot(self_v_rot);
}
}
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
@ -532,6 +567,14 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
if (inp_attn->self_k_rot) {
mctx->get_attn()->set_input_k_rot(inp_attn->self_k_rot);
}
if (inp_attn->self_v_rot) {
mctx->get_attn()->set_input_v_rot(inp_attn->self_v_rot);
}
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (inp_rs->s_copy) {
@ -630,6 +673,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
}
if (inp_attn->self_k_rot) {
attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot);
}
if (inp_attn->self_v_rot) {
attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
}
const int64_t n_rs = mctx->get_recr()->get_n_rs();
if (inp_rs->s_copy) {
@ -2002,13 +2053,13 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0);
inp->self_v_rot = mctx_cur->build_input_v_rot(ctx0);
return inp;
}
@ -2034,6 +2085,15 @@ ggml_tensor * llm_graph_context::build_attn(
int il) const {
GGML_ASSERT(v_mla == nullptr);
if (inp->self_k_rot) {
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
}
if (inp->self_v_rot) {
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
}
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
// expand k later to enable rope fusion which directly writes into k-v cache
@ -2061,6 +2121,10 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
if (inp->self_v_rot) {
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
}
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
@ -2090,9 +2154,7 @@ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
ggml_set_input(inp->self_kq_mask);
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
@ -2171,6 +2233,18 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * v_mla,
float kq_scale,
int il) const {
if (inp->self_k_rot) {
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
if (k_cur) {
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
}
}
if (inp->self_v_rot) {
if (v_cur) {
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
}
}
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, q_cur);
@ -2211,6 +2285,10 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);
if (inp->self_v_rot) {
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
}
if (wo) {
cur = build_lora_mm(wo, cur);
}
@ -2293,12 +2371,8 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
ggml_set_input(inp->self_kq_mask);
ggml_set_name(inp->self_kq_mask, "self_kq_mask");
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
}
{
@ -2307,14 +2381,13 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
ggml_set_input(inp->self_kq_mask_swa);
ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
}
inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
}
@ -2473,9 +2546,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
ggml_set_input(inp_attn->self_kq_mask);
inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
}
@ -2483,9 +2554,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
ggml_set_input(inp_attn->self_kq_mask_swa);
inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
}

View File

@ -308,6 +308,10 @@ public:
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
// note: assumes v_rot^ == I
ggml_tensor * self_k_rot = nullptr;
ggml_tensor * self_v_rot = nullptr;
// note: these have to be copies because in order to be able to reuse a graph, its inputs
// need to carry these parameters with them. otherwise, they can point to freed
// llm_graph_params from a previous batch, causing stack-use-after-return
@ -384,6 +388,10 @@ public:
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
// note: using same rotation matrices for both base and swa cache
ggml_tensor * self_k_rot = nullptr;
ggml_tensor * self_v_rot = nullptr;
const llama_hparams hparams;
const llama_cparams cparams;

View File

@ -13,6 +13,65 @@
#include <map>
#include <stdexcept>
static bool ggml_is_power_of_2(int n) {
return (n & (n - 1)) == 0;
}
// orthonormal Walsh-Hadamard rotation matrix
// note: res^2 == I
static void ggml_gen_hadamard(ggml_tensor * tensor) {
assert(tensor->type == GGML_TYPE_F32);
const int n = tensor->ne[0];
assert(ggml_is_power_of_2(n));
assert(tensor->ne[1] == n);
assert(tensor->ne[2] == 1);
assert(tensor->ne[3] == 1);
std::vector<float> data_f32;
float * data = (float *) tensor->data;
if (tensor->type != GGML_TYPE_F32) {
data_f32.resize(n*n);
data = data_f32.data();
}
data[0*n + 0] = 1.0 / sqrtf(n);
for (int s = 1; s < n; s *= 2) {
for (int i = 0; i < s; i++) {
for (int j = 0; j < s; j++) {
const float val = data[i*n + j];
data[(i + s)*n + (j )] = val;
data[(i )*n + (j + s)] = val;
data[(i + s)*n + (j + s)] = -val;
}
}
}
if (tensor->type != GGML_TYPE_F32) {
ggml_quantize_chunk(tensor->type, data, tensor->data, 0, 1, n*n, nullptr);
}
}
static ggml_tensor * ggml_mul_mat_aux(
ggml_context * ctx,
ggml_tensor * cur,
ggml_tensor * rot) {
const auto n = rot->ne[0];
ggml_tensor * res;
res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
res = ggml_mul_mat (ctx, rot, res);
res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);
return res;
}
//
// llama_kv_cache
//
@ -209,6 +268,48 @@ llama_kv_cache::llama_kv_cache(
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE");
const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false;
if (attn_rot_disable) {
LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__);
}
attn_rot_k =
!attn_rot_disable &&
ggml_is_quantized(type_k) &&
!hparams.is_n_embd_k_gqa_variable() &&
hparams.n_embd_head_k() % 64 == 0;
attn_rot_v =
!attn_rot_disable &&
ggml_is_quantized(type_v) &&
!hparams.is_n_embd_v_gqa_variable() &&
hparams.n_embd_head_v() % 64 == 0;
LLAMA_LOG_INFO("%s: attn_rot_k = %d\n", __func__, attn_rot_k);
LLAMA_LOG_INFO("%s: attn_rot_v = %d\n", __func__, attn_rot_v);
// pre-compute the haramard matrices and keep them in host memory
// TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers
if (attn_rot_k || attn_rot_v) {
for (int64_t n = 64; n <= std::max(hparams.n_embd_head_k(), hparams.n_embd_head_v()); n *= 2) {
attn_rot_hadamard[n] = std::vector<float>(n*n);
ggml_init_params params = {
/* .mem_size = */ 1*ggml_tensor_overhead(),
/* .mem_buffer = */ nullptr,
/* .no_alloc = */ true,
};
ggml_context_ptr ctx { ggml_init(params) };
ggml_tensor * tmp = ggml_new_tensor_2d(ctx.get(), GGML_TYPE_F32, n, n);
tmp->data = attn_rot_hadamard[n].data();
ggml_gen_hadamard(tmp);
}
}
const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
}
@ -1004,6 +1105,14 @@ bool llama_kv_cache::get_has_shift() const {
return result;
}
ggml_type llama_kv_cache::type_k() const {
return layers[0].k->type;
}
ggml_type llama_kv_cache::type_v() const {
return layers[0].v->type;
}
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
uint32_t result = 0;
@ -1189,6 +1298,47 @@ ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama
return v_idxs;
}
ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const {
ggml_tensor * res = nullptr;
if (attn_rot_k) {
int nrot = 64;
// TODO: investigate if using the smallest rotation matrix is beneficial also for K (similar as for V)
// ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088
do {
nrot *= 2;
} while (hparams.n_embd_head_k() % nrot == 0);
nrot /= 2;
res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);
ggml_set_input(res);
ggml_set_name(res, "attn_inp_k_rot");
}
return res;
}
ggml_tensor * llama_kv_cache::build_input_v_rot(ggml_context * ctx) const {
ggml_tensor * res = nullptr;
if (attn_rot_v) {
int nrot = 64;
// using smaller rotation matrices for V seems beneficial
// ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4146397570
//do {
// nrot *= 2;
//} while (hparams.n_embd_head_v() % nrot == 0);
//nrot /= 2;
res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);
ggml_set_input(res);
ggml_set_name(res, "attn_inp_v_rot");
}
return res;
}
void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
const uint32_t n_tokens = ubatch->n_tokens;
GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
@ -1507,6 +1657,24 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
}
}
void llama_kv_cache::set_input_k_rot(ggml_tensor * dst) const {
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
const auto n_rot = dst->ne[0];
GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0]));
memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst));
}
void llama_kv_cache::set_input_v_rot(ggml_tensor * dst) const {
GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
const auto n_rot = dst->ne[0];
GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0]));
memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst));
}
size_t llama_kv_cache::total_size() const {
size_t size = 0;
@ -1542,6 +1710,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
ggml_context * ctx,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * rot,
ggml_tensor * factors,
float freq_base,
float freq_scale,
@ -1567,10 +1736,16 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
// dequantize to f32 -> RoPE -> quantize back
tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
// rotate back
tmp = ggml_mul_mat_aux(ctx, tmp, rot);
tmp = ggml_rope_ext(ctx, tmp,
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
// rotate fwd
tmp = ggml_mul_mat_aux(ctx, tmp, rot);
tmp = ggml_cpy(ctx, tmp, cur);
} else {
// we rotate only the first n_rot dimensions
@ -1591,6 +1766,9 @@ public:
ggml_tensor * k_shift; // I32 [kv_size*n_stream]
// note: assumes k_rot^2 == I
ggml_tensor * k_rot = nullptr;
const llama_kv_cache * kv_self;
};
@ -1600,6 +1778,10 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
if (k_shift) {
kv_self->set_input_k_shift(k_shift);
}
if (k_rot) {
kv_self->set_input_k_rot(k_rot);
}
}
ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
@ -1611,6 +1793,8 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
ggml_set_input(inp->k_shift);
inp->k_rot = build_input_k_rot(ctx);
const auto & cparams = lctx->get_cparams();
for (const auto & layer : layers) {
@ -1635,7 +1819,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
ggml_row_size(layer.k->type, n_embd_k_gqa),
ggml_row_size(layer.k->type, n_embd_nope));
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il);
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, inp->k_rot, rope_factors, freq_base_l, freq_scale_l, il);
ggml_build_forward_expand(gf, cur);
}
@ -2239,6 +2423,14 @@ uint32_t llama_kv_cache_context::get_n_kv() const {
return n_kv;
}
ggml_type llama_kv_cache_context::type_k() const {
return kv->type_k();
}
ggml_type llama_kv_cache_context::type_v() const {
return kv->type_v();
}
ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
}
@ -2263,6 +2455,14 @@ ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, con
return kv->build_input_v_idxs(ctx, ubatch);
}
ggml_tensor * llama_kv_cache_context::build_input_k_rot(ggml_context * ctx) const {
return kv->build_input_k_rot(ctx);
}
ggml_tensor * llama_kv_cache_context::build_input_v_rot(ggml_context * ctx) const {
return kv->build_input_v_rot(ctx);
}
void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
kv->set_input_k_shift(dst);
}
@ -2282,3 +2482,11 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
kv->set_input_pos_bucket(dst, ubatch);
}
void llama_kv_cache_context::set_input_k_rot(ggml_tensor * dst) const {
kv->set_input_k_rot(dst);
}
void llama_kv_cache_context::set_input_v_rot(ggml_tensor * dst) const {
kv->set_input_v_rot(dst);
}

View File

@ -152,6 +152,9 @@ public:
bool get_has_shift() const;
ggml_type type_k() const;
ggml_type type_v() const;
//
// graph_build API
//
@ -191,6 +194,9 @@ public:
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
ggml_tensor * build_input_v_rot(ggml_context * ctx) const;
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
@ -199,6 +205,9 @@ public:
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
void set_input_k_rot(ggml_tensor * dst) const;
void set_input_v_rot(ggml_tensor * dst) const;
private:
const llama_model & model;
const llama_hparams & hparams;
@ -226,6 +235,13 @@ private:
// SWA
const uint32_t n_swa = 0;
// env: LLAMA_ATTN_ROT_DISABLE
bool attn_rot_k = false;
bool attn_rot_v = false;
// pre-computed hadamard martrices
std::unordered_map<int64_t, std::vector<float>> attn_rot_hadamard;
// env: LLAMA_KV_CACHE_DEBUG
int debug = 0;
@ -262,6 +278,7 @@ private:
ggml_context * ctx,
ggml_tensor * cur,
ggml_tensor * shift,
ggml_tensor * rot,
ggml_tensor * factors,
float freq_base,
float freq_scale,
@ -328,6 +345,9 @@ public:
uint32_t get_n_kv() const;
ggml_type type_k() const;
ggml_type type_v() const;
// get views of the current state of the cache
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
@ -347,6 +367,9 @@ public:
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
ggml_tensor * build_input_v_rot(ggml_context * ctx) const;
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@ -354,6 +377,9 @@ public:
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
void set_input_k_rot(ggml_tensor * dst) const;
void set_input_v_rot(ggml_tensor * dst) const;
private:
llama_memory_status status;