llama : rotate activations for better quantization (#21038)
* llama : rotate activations for better quantization * cont : rotate V more + refactor * cont : rotate caches separately + support non-power-of-2 head sizes * cont : simplify * cont : add reference for V rotation * cont : refactor * cont : support context shift * cont : consolidate * cont : dedup + allow different types for the rotation matrix * cont : add env variable to disable rotation * cont : simplify attn rot kv cache logic + rename env * cont : pre-compute the Hadamard matrices
This commit is contained in:
parent
0356e33aaf
commit
744c0c7310
|
|
@ -19,7 +19,7 @@
|
|||
|
||||
// dedup helpers
|
||||
|
||||
static ggml_tensor * build_kq_mask(
|
||||
static ggml_tensor * build_attn_inp_kq_mask(
|
||||
ggml_context * ctx,
|
||||
const llama_kv_cache_context * mctx,
|
||||
const llama_ubatch & ubatch,
|
||||
|
|
@ -28,7 +28,11 @@ static ggml_tensor * build_kq_mask(
|
|||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_tensor * res = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
ggml_set_input(res);
|
||||
ggml_set_name(res, "attn_inp_kq_mask");
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
static bool can_reuse_kq_mask(
|
||||
|
|
@ -52,6 +56,21 @@ static bool can_reuse_kq_mask(
|
|||
|
||||
// impl
|
||||
|
||||
// apply the n x n rotation matrix `rot` to every consecutive group of n elements
// of `cur` (i.e. rotate along the first dimension), preserving cur's original shape
// assumes ggml_nelements(cur) is divisible by rot->ne[0] -- TODO confirm callers
// guarantee cur->ne[0] % n == 0
static ggml_tensor * ggml_mul_mat_aux(
        ggml_context * ctx,
        ggml_tensor * cur,
        ggml_tensor * rot) {
    const auto n = rot->ne[0];

    ggml_tensor * res;

    // flatten to [n, nelements/n] so the matrix multiply rotates each n-sized group
    res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
    res = ggml_mul_mat   (ctx, rot, res);
    // restore the original 4-d shape
    res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);

    return res;
}
|
||||
|
||||
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
||||
if (ubatch->token) {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
|
@ -429,6 +448,14 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
|
|||
mctx->set_input_v_idxs(self_v_idxs, ubatch);
|
||||
|
||||
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
|
||||
if (self_k_rot) {
|
||||
mctx->set_input_k_rot(self_k_rot);
|
||||
}
|
||||
|
||||
if (self_v_rot) {
|
||||
mctx->set_input_v_rot(self_v_rot);
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
|
||||
|
|
@ -476,6 +503,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
|||
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
|
||||
|
||||
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
|
||||
if (self_k_rot) {
|
||||
mctx->get_base()->set_input_k_rot(self_k_rot);
|
||||
}
|
||||
|
||||
if (self_v_rot) {
|
||||
mctx->get_base()->set_input_v_rot(self_v_rot);
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
||||
|
|
@ -532,6 +567,14 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
|||
|
||||
mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
|
||||
|
||||
if (inp_attn->self_k_rot) {
|
||||
mctx->get_attn()->set_input_k_rot(inp_attn->self_k_rot);
|
||||
}
|
||||
|
||||
if (inp_attn->self_v_rot) {
|
||||
mctx->get_attn()->set_input_v_rot(inp_attn->self_v_rot);
|
||||
}
|
||||
|
||||
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
||||
|
||||
if (inp_rs->s_copy) {
|
||||
|
|
@ -630,6 +673,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
|
|||
attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
|
||||
}
|
||||
|
||||
if (inp_attn->self_k_rot) {
|
||||
attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot);
|
||||
}
|
||||
|
||||
if (inp_attn->self_v_rot) {
|
||||
attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
|
||||
}
|
||||
|
||||
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
||||
|
||||
if (inp_rs->s_copy) {
|
||||
|
|
@ -2002,13 +2053,13 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
|
|||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
||||
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
}
|
||||
|
||||
inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0);
|
||||
inp->self_v_rot = mctx_cur->build_input_v_rot(ctx0);
|
||||
|
||||
return inp;
|
||||
}
|
||||
|
||||
|
|
@ -2034,6 +2085,15 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
int il) const {
|
||||
GGML_ASSERT(v_mla == nullptr);
|
||||
|
||||
if (inp->self_k_rot) {
|
||||
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
|
||||
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
|
||||
}
|
||||
|
||||
if (inp->self_v_rot) {
|
||||
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
|
||||
}
|
||||
|
||||
// these nodes are added to the graph together so that they are not reordered
|
||||
// by doing so, the number of splits in the graph is reduced
|
||||
// expand k later to enable rope fusion which directly writes into k-v cache
|
||||
|
|
@ -2061,6 +2121,10 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (inp->self_v_rot) {
|
||||
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
|
||||
}
|
||||
|
||||
if (wo) {
|
||||
cur = build_lora_mm(wo, cur);
|
||||
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) {
|
||||
|
|
@ -2090,9 +2154,7 @@ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
|
|||
|
||||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
}
|
||||
|
||||
|
|
@ -2171,6 +2233,18 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
if (inp->self_k_rot) {
|
||||
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
|
||||
if (k_cur) {
|
||||
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
|
||||
}
|
||||
}
|
||||
if (inp->self_v_rot) {
|
||||
if (v_cur) {
|
||||
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
|
||||
}
|
||||
}
|
||||
|
||||
// these nodes are added to the graph together so that they are not reordered
|
||||
// by doing so, the number of splits in the graph is reduced
|
||||
ggml_build_forward_expand(gf, q_cur);
|
||||
|
|
@ -2211,6 +2285,10 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (inp->self_v_rot) {
|
||||
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
|
||||
}
|
||||
|
||||
if (wo) {
|
||||
cur = build_lora_mm(wo, cur);
|
||||
}
|
||||
|
|
@ -2293,12 +2371,8 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
|||
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
ggml_set_name(inp->self_kq_mask, "self_kq_mask");
|
||||
|
||||
inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
|
||||
}
|
||||
|
||||
{
|
||||
|
|
@ -2307,14 +2381,13 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
|||
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
|
||||
|
||||
inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
|
||||
}
|
||||
|
||||
inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
|
||||
inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);
|
||||
|
||||
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
|
|
@ -2473,9 +2546,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
|
|||
inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
|
||||
ggml_set_input(inp_attn->self_kq_mask);
|
||||
|
||||
inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
|
||||
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
|
||||
}
|
||||
|
||||
|
|
@ -2483,9 +2554,7 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
|
|||
inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
|
||||
ggml_set_input(inp_attn->self_kq_mask_swa);
|
||||
|
||||
inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
|
||||
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -308,6 +308,10 @@ public:
|
|||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
// note: assumes v_rot^2 == I
|
||||
ggml_tensor * self_k_rot = nullptr;
|
||||
ggml_tensor * self_v_rot = nullptr;
|
||||
|
||||
// note: these have to be copies because in order to be able to reuse a graph, its inputs
|
||||
// need to carry these parameters with them. otherwise, they can point to freed
|
||||
// llm_graph_params from a previous batch, causing stack-use-after-return
|
||||
|
|
@ -384,6 +388,10 @@ public:
|
|||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
// note: using same rotation matrices for both base and swa cache
|
||||
ggml_tensor * self_k_rot = nullptr;
|
||||
ggml_tensor * self_v_rot = nullptr;
|
||||
|
||||
const llama_hparams hparams;
|
||||
const llama_cparams cparams;
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,65 @@
|
|||
#include <map>
|
||||
#include <stdexcept>
|
||||
|
||||
// true iff n is a positive power of two (1, 2, 4, ...)
// note: the n > 0 guard rejects 0 (which would otherwise pass the bit trick)
// and avoids the signed-overflow UB of (n - 1) for INT_MIN
static bool ggml_is_power_of_2(int n) {
    return n > 0 && (n & (n - 1)) == 0;
}
|
||||
|
||||
// orthonormal Walsh-Hadamard rotation matrix
|
||||
// note: res^2 == I
|
||||
static void ggml_gen_hadamard(ggml_tensor * tensor) {
|
||||
assert(tensor->type == GGML_TYPE_F32);
|
||||
|
||||
const int n = tensor->ne[0];
|
||||
|
||||
assert(ggml_is_power_of_2(n));
|
||||
assert(tensor->ne[1] == n);
|
||||
assert(tensor->ne[2] == 1);
|
||||
assert(tensor->ne[3] == 1);
|
||||
|
||||
std::vector<float> data_f32;
|
||||
|
||||
float * data = (float *) tensor->data;
|
||||
|
||||
if (tensor->type != GGML_TYPE_F32) {
|
||||
data_f32.resize(n*n);
|
||||
data = data_f32.data();
|
||||
}
|
||||
|
||||
data[0*n + 0] = 1.0 / sqrtf(n);
|
||||
|
||||
for (int s = 1; s < n; s *= 2) {
|
||||
for (int i = 0; i < s; i++) {
|
||||
for (int j = 0; j < s; j++) {
|
||||
const float val = data[i*n + j];
|
||||
|
||||
data[(i + s)*n + (j )] = val;
|
||||
data[(i )*n + (j + s)] = val;
|
||||
data[(i + s)*n + (j + s)] = -val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (tensor->type != GGML_TYPE_F32) {
|
||||
ggml_quantize_chunk(tensor->type, data, tensor->data, 0, 1, n*n, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// apply the n x n rotation matrix `rot` to every consecutive group of n elements
// of `cur` (i.e. rotate along the first dimension), preserving cur's original shape
// assumes ggml_nelements(cur) is divisible by rot->ne[0] -- TODO confirm callers
// guarantee cur->ne[0] % n == 0
static ggml_tensor * ggml_mul_mat_aux(
        ggml_context * ctx,
        ggml_tensor * cur,
        ggml_tensor * rot) {
    const auto n = rot->ne[0];

    ggml_tensor * res;

    // flatten to [n, nelements/n] so the matrix multiply rotates each n-sized group
    res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n);
    res = ggml_mul_mat   (ctx, rot, res);
    // restore the original 4-d shape
    res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]);

    return res;
}
|
||||
|
||||
//
|
||||
// llama_kv_cache
|
||||
//
|
||||
|
|
@ -209,6 +268,48 @@ llama_kv_cache::llama_kv_cache(
|
|||
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
||||
}
|
||||
|
||||
const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE");
|
||||
const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false;
|
||||
if (attn_rot_disable) {
|
||||
LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__);
|
||||
}
|
||||
|
||||
attn_rot_k =
|
||||
!attn_rot_disable &&
|
||||
ggml_is_quantized(type_k) &&
|
||||
!hparams.is_n_embd_k_gqa_variable() &&
|
||||
hparams.n_embd_head_k() % 64 == 0;
|
||||
|
||||
attn_rot_v =
|
||||
!attn_rot_disable &&
|
||||
ggml_is_quantized(type_v) &&
|
||||
!hparams.is_n_embd_v_gqa_variable() &&
|
||||
hparams.n_embd_head_v() % 64 == 0;
|
||||
|
||||
LLAMA_LOG_INFO("%s: attn_rot_k = %d\n", __func__, attn_rot_k);
|
||||
LLAMA_LOG_INFO("%s: attn_rot_v = %d\n", __func__, attn_rot_v);
|
||||
|
||||
// pre-compute the Hadamard matrices and keep them in host memory
|
||||
// TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers
|
||||
if (attn_rot_k || attn_rot_v) {
|
||||
for (int64_t n = 64; n <= std::max(hparams.n_embd_head_k(), hparams.n_embd_head_v()); n *= 2) {
|
||||
attn_rot_hadamard[n] = std::vector<float>(n*n);
|
||||
|
||||
ggml_init_params params = {
|
||||
/* .mem_size = */ 1*ggml_tensor_overhead(),
|
||||
/* .mem_buffer = */ nullptr,
|
||||
/* .no_alloc = */ true,
|
||||
};
|
||||
|
||||
ggml_context_ptr ctx { ggml_init(params) };
|
||||
|
||||
ggml_tensor * tmp = ggml_new_tensor_2d(ctx.get(), GGML_TYPE_F32, n, n);
|
||||
tmp->data = attn_rot_hadamard[n].data();
|
||||
|
||||
ggml_gen_hadamard(tmp);
|
||||
}
|
||||
}
|
||||
|
||||
const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
|
||||
debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
|
||||
}
|
||||
|
|
@ -1004,6 +1105,14 @@ bool llama_kv_cache::get_has_shift() const {
|
|||
return result;
|
||||
}
|
||||
|
||||
// type of the K cache tensors
// note: reported from the first layer (assumes uniform K type across layers -- TODO confirm)
ggml_type llama_kv_cache::type_k() const {
    return layers[0].k->type;
}

// type of the V cache tensors
// note: reported from the first layer (assumes uniform V type across layers -- TODO confirm)
ggml_type llama_kv_cache::type_v() const {
    return layers[0].v->type;
}
|
||||
|
||||
uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const {
|
||||
uint32_t result = 0;
|
||||
|
||||
|
|
@ -1189,6 +1298,47 @@ ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama
|
|||
return v_idxs;
|
||||
}
|
||||
|
||||
// create the graph input tensor that will hold the Hadamard rotation matrix for the K cache
// returns nullptr when K rotation is disabled (attn_rot_k == false)
ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const {
    ggml_tensor * res = nullptr;

    if (attn_rot_k) {
        int nrot = 64;

        // TODO: investigate if using the smallest rotation matrix is beneficial also for K (similar as for V)
        // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088
        // pick the largest power-of-2 divisor of the K head size
        // (at least 64 -- the attn_rot_k condition requires n_embd_head_k() % 64 == 0)
        do {
            nrot *= 2;
        } while (hparams.n_embd_head_k() % nrot == 0);
        nrot /= 2;

        res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);
        ggml_set_input(res);
        ggml_set_name(res, "attn_inp_k_rot");
    }

    return res;
}
|
||||
|
||||
// create the graph input tensor that will hold the Hadamard rotation matrix for the V cache
// returns nullptr when V rotation is disabled (attn_rot_v == false)
ggml_tensor * llama_kv_cache::build_input_v_rot(ggml_context * ctx) const {
    ggml_tensor * res = nullptr;

    if (attn_rot_v) {
        // the smallest supported rotation size is used deliberately (loop below kept for reference)
        int nrot = 64;
        // using smaller rotation matrices for V seems beneficial
        // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4146397570
        //do {
        //    nrot *= 2;
        //} while (hparams.n_embd_head_v() % nrot == 0);
        //nrot /= 2;

        res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);
        ggml_set_input(res);
        ggml_set_name(res, "attn_inp_v_rot");
    }

    return res;
}
|
||||
|
||||
void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
|
||||
const uint32_t n_tokens = ubatch->n_tokens;
|
||||
GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream());
|
||||
|
|
@ -1507,6 +1657,24 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch
|
|||
}
|
||||
}
|
||||
|
||||
// copy the pre-computed Hadamard matrix matching the size of `dst` into the K-rotation graph input
void llama_kv_cache::set_input_k_rot(ggml_tensor * dst) const {
    // the matrices are kept in host memory, so the destination buffer must be host-accessible
    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));

    const auto n_rot = dst->ne[0];
    // a matrix of this size must have been pre-computed in the constructor
    GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0]));

    memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst));
}

// copy the pre-computed Hadamard matrix matching the size of `dst` into the V-rotation graph input
void llama_kv_cache::set_input_v_rot(ggml_tensor * dst) const {
    // the matrices are kept in host memory, so the destination buffer must be host-accessible
    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));

    const auto n_rot = dst->ne[0];
    // a matrix of this size must have been pre-computed in the constructor
    GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0]));

    memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst));
}
|
||||
|
||||
size_t llama_kv_cache::total_size() const {
|
||||
size_t size = 0;
|
||||
|
||||
|
|
@ -1542,6 +1710,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
|
|||
ggml_context * ctx,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * shift,
|
||||
ggml_tensor * rot,
|
||||
ggml_tensor * factors,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
|
|
@ -1567,10 +1736,16 @@ ggml_tensor * llama_kv_cache::build_rope_shift(
|
|||
// dequantize to f32 -> RoPE -> quantize back
|
||||
tmp = ggml_cast(ctx, cur, GGML_TYPE_F32);
|
||||
|
||||
// rotate back
|
||||
tmp = ggml_mul_mat_aux(ctx, tmp, rot);
|
||||
|
||||
tmp = ggml_rope_ext(ctx, tmp,
|
||||
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
|
||||
|
||||
// rotate fwd
|
||||
tmp = ggml_mul_mat_aux(ctx, tmp, rot);
|
||||
|
||||
tmp = ggml_cpy(ctx, tmp, cur);
|
||||
} else {
|
||||
// we rotate only the first n_rot dimensions
|
||||
|
|
@ -1591,6 +1766,9 @@ public:
|
|||
|
||||
ggml_tensor * k_shift; // I32 [kv_size*n_stream]
|
||||
|
||||
// note: assumes k_rot^2 == I
|
||||
ggml_tensor * k_rot = nullptr;
|
||||
|
||||
const llama_kv_cache * kv_self;
|
||||
};
|
||||
|
||||
|
|
@ -1600,6 +1778,10 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
|
|||
if (k_shift) {
|
||||
kv_self->set_input_k_shift(k_shift);
|
||||
}
|
||||
|
||||
if (k_rot) {
|
||||
kv_self->set_input_k_rot(k_rot);
|
||||
}
|
||||
}
|
||||
|
||||
ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
|
||||
|
|
@ -1611,6 +1793,8 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
|||
inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
|
||||
ggml_set_input(inp->k_shift);
|
||||
|
||||
inp->k_rot = build_input_k_rot(ctx);
|
||||
|
||||
const auto & cparams = lctx->get_cparams();
|
||||
|
||||
for (const auto & layer : layers) {
|
||||
|
|
@ -1635,7 +1819,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
|
|||
ggml_row_size(layer.k->type, n_embd_k_gqa),
|
||||
ggml_row_size(layer.k->type, n_embd_nope));
|
||||
|
||||
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, il);
|
||||
ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, inp->k_rot, rope_factors, freq_base_l, freq_scale_l, il);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
|
@ -2239,6 +2423,14 @@ uint32_t llama_kv_cache_context::get_n_kv() const {
|
|||
return n_kv;
|
||||
}
|
||||
|
||||
// forward to the underlying kv cache
ggml_type llama_kv_cache_context::type_k() const {
    return kv->type_k();
}

// forward to the underlying kv cache
ggml_type llama_kv_cache_context::type_v() const {
    return kv->type_v();
}
|
||||
|
||||
ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
|
||||
return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
|
||||
}
|
||||
|
|
@ -2263,6 +2455,14 @@ ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, con
|
|||
return kv->build_input_v_idxs(ctx, ubatch);
|
||||
}
|
||||
|
||||
// forward to the underlying kv cache
ggml_tensor * llama_kv_cache_context::build_input_k_rot(ggml_context * ctx) const {
    return kv->build_input_k_rot(ctx);
}

// forward to the underlying kv cache
ggml_tensor * llama_kv_cache_context::build_input_v_rot(ggml_context * ctx) const {
    return kv->build_input_v_rot(ctx);
}
|
||||
|
||||
void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
|
||||
kv->set_input_k_shift(dst);
|
||||
}
|
||||
|
|
@ -2282,3 +2482,11 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub
|
|||
void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
|
||||
kv->set_input_pos_bucket(dst, ubatch);
|
||||
}
|
||||
|
||||
// forward to the underlying kv cache
void llama_kv_cache_context::set_input_k_rot(ggml_tensor * dst) const {
    kv->set_input_k_rot(dst);
}

// forward to the underlying kv cache
void llama_kv_cache_context::set_input_v_rot(ggml_tensor * dst) const {
    kv->set_input_v_rot(dst);
}
|
||||
|
|
|
|||
|
|
@ -152,6 +152,9 @@ public:
|
|||
|
||||
bool get_has_shift() const;
|
||||
|
||||
ggml_type type_k() const;
|
||||
ggml_type type_v() const;
|
||||
|
||||
//
|
||||
// graph_build API
|
||||
//
|
||||
|
|
@ -191,6 +194,9 @@ public:
|
|||
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
|
||||
ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
|
||||
ggml_tensor * build_input_v_rot(ggml_context * ctx) const;
|
||||
|
||||
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
|
||||
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
|
||||
|
||||
|
|
@ -199,6 +205,9 @@ public:
|
|||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
void set_input_k_rot(ggml_tensor * dst) const;
|
||||
void set_input_v_rot(ggml_tensor * dst) const;
|
||||
|
||||
private:
|
||||
const llama_model & model;
|
||||
const llama_hparams & hparams;
|
||||
|
|
@ -226,6 +235,13 @@ private:
|
|||
// SWA
|
||||
const uint32_t n_swa = 0;
|
||||
|
||||
// env: LLAMA_ATTN_ROT_DISABLE
|
||||
bool attn_rot_k = false;
|
||||
bool attn_rot_v = false;
|
||||
|
||||
// pre-computed Hadamard matrices
|
||||
std::unordered_map<int64_t, std::vector<float>> attn_rot_hadamard;
|
||||
|
||||
// env: LLAMA_KV_CACHE_DEBUG
|
||||
int debug = 0;
|
||||
|
||||
|
|
@ -262,6 +278,7 @@ private:
|
|||
ggml_context * ctx,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * shift,
|
||||
ggml_tensor * rot,
|
||||
ggml_tensor * factors,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
|
|
@ -328,6 +345,9 @@ public:
|
|||
|
||||
uint32_t get_n_kv() const;
|
||||
|
||||
ggml_type type_k() const;
|
||||
ggml_type type_v() const;
|
||||
|
||||
// get views of the current state of the cache
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
|
||||
|
|
@ -347,6 +367,9 @@ public:
|
|||
ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
|
||||
|
||||
ggml_tensor * build_input_k_rot(ggml_context * ctx) const;
|
||||
ggml_tensor * build_input_v_rot(ggml_context * ctx) const;
|
||||
|
||||
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
|
|
@ -354,6 +377,9 @@ public:
|
|||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
void set_input_k_rot(ggml_tensor * dst) const;
|
||||
void set_input_v_rot(ggml_tensor * dst) const;
|
||||
|
||||
private:
|
||||
llama_memory_status status;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue