kv-cache : support attention rotation for heterogeneous iSWA (#21513)
* kv-cache : support attention rotation for heterogeneous iSWA * cont : remove assert
This commit is contained in:
parent
957d717ce5
commit
4eb19514dd
|
|
@ -511,6 +511,14 @@ void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
|
|||
if (self_v_rot) {
|
||||
mctx->get_base()->set_input_v_rot(self_v_rot);
|
||||
}
|
||||
|
||||
if (self_k_rot_swa) {
|
||||
mctx->get_swa()->set_input_k_rot(self_k_rot_swa);
|
||||
}
|
||||
|
||||
if (self_v_rot_swa) {
|
||||
mctx->get_swa()->set_input_v_rot(self_v_rot_swa);
|
||||
}
|
||||
}
|
||||
|
||||
bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
||||
|
|
@ -681,6 +689,14 @@ void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
|
|||
attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot);
|
||||
}
|
||||
|
||||
if (inp_attn->self_k_rot_swa) {
|
||||
attn_ctx->get_swa()->set_input_k_rot(inp_attn->self_k_rot_swa);
|
||||
}
|
||||
|
||||
if (inp_attn->self_v_rot_swa) {
|
||||
attn_ctx->get_swa()->set_input_v_rot(inp_attn->self_v_rot_swa);
|
||||
}
|
||||
|
||||
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
||||
|
||||
if (inp_rs->s_copy) {
|
||||
|
|
@ -2233,15 +2249,20 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
if (inp->self_k_rot) {
|
||||
q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot);
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
|
||||
auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot;
|
||||
auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot;
|
||||
|
||||
if (k_rot) {
|
||||
q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot);
|
||||
if (k_cur) {
|
||||
k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot);
|
||||
k_cur = ggml_mul_mat_aux(ctx0, k_cur, k_rot);
|
||||
}
|
||||
}
|
||||
if (inp->self_v_rot) {
|
||||
if (v_rot) {
|
||||
if (v_cur) {
|
||||
v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot);
|
||||
v_cur = ggml_mul_mat_aux(ctx0, v_cur, v_rot);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2259,8 +2280,6 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
|
||||
const auto * mctx_iswa = inp->mctx;
|
||||
|
||||
const bool is_swa = hparams.is_swa(il);
|
||||
|
||||
const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
|
||||
|
||||
// optionally store to KV cache
|
||||
|
|
@ -2285,8 +2304,8 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (inp->self_v_rot) {
|
||||
cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot);
|
||||
if (v_rot) {
|
||||
cur = ggml_mul_mat_aux(ctx0, cur, v_rot);
|
||||
}
|
||||
|
||||
if (wo) {
|
||||
|
|
@ -2388,6 +2407,9 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
|||
inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0);
|
||||
inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0);
|
||||
|
||||
inp->self_k_rot_swa = mctx_cur->get_swa()->build_input_k_rot(ctx0);
|
||||
inp->self_v_rot_swa = mctx_cur->get_swa()->build_input_v_rot(ctx0);
|
||||
|
||||
return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -308,7 +308,7 @@ public:
|
|||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
// note: assumes v_rot^ == I
|
||||
// note: assumes v_rot^2 == I
|
||||
ggml_tensor * self_k_rot = nullptr;
|
||||
ggml_tensor * self_v_rot = nullptr;
|
||||
|
||||
|
|
@ -388,10 +388,12 @@ public:
|
|||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
// note: using same rotation matrices for both base and swa cache
|
||||
ggml_tensor * self_k_rot = nullptr;
|
||||
ggml_tensor * self_v_rot = nullptr;
|
||||
|
||||
ggml_tensor * self_k_rot_swa = nullptr;
|
||||
ggml_tensor * self_v_rot_swa = nullptr;
|
||||
|
||||
const llama_hparams hparams;
|
||||
const llama_cparams cparams;
|
||||
|
||||
|
|
|
|||
|
|
@ -169,6 +169,18 @@ llama_kv_cache::llama_kv_cache(
|
|||
continue;
|
||||
}
|
||||
|
||||
if (n_embd_head_k_all == 0) {
|
||||
n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il);
|
||||
} else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) {
|
||||
n_embd_head_k_all = -1;
|
||||
}
|
||||
|
||||
if (n_embd_head_v_all == 0) {
|
||||
n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il);
|
||||
} else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) {
|
||||
n_embd_head_v_all = -1;
|
||||
}
|
||||
|
||||
// [TAG_V_CACHE_VARIABLE]
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||
const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max();
|
||||
|
|
@ -276,23 +288,23 @@ llama_kv_cache::llama_kv_cache(
|
|||
|
||||
attn_rot_k =
|
||||
!attn_rot_disable &&
|
||||
n_embd_head_k_all > 0 &&
|
||||
ggml_is_quantized(type_k) &&
|
||||
!hparams.is_n_embd_k_gqa_variable() &&
|
||||
hparams.n_embd_head_k() % 64 == 0;
|
||||
|
||||
attn_rot_v =
|
||||
!attn_rot_disable &&
|
||||
n_embd_head_v_all > 0 &&
|
||||
ggml_is_quantized(type_v) &&
|
||||
!hparams.is_n_embd_v_gqa_variable() &&
|
||||
hparams.n_embd_head_v() % 64 == 0;
|
||||
|
||||
LLAMA_LOG_INFO("%s: attn_rot_k = %d\n", __func__, attn_rot_k);
|
||||
LLAMA_LOG_INFO("%s: attn_rot_v = %d\n", __func__, attn_rot_v);
|
||||
LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all);
|
||||
LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all);
|
||||
|
||||
// pre-compute the haramard matrices and keep them in host memory
|
||||
// TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers
|
||||
if (attn_rot_k || attn_rot_v) {
|
||||
for (int64_t n = 64; n <= std::max(hparams.n_embd_head_k(), hparams.n_embd_head_v()); n *= 2) {
|
||||
for (int64_t n = 64; n <= std::max(n_embd_head_k_all, n_embd_head_v_all); n *= 2) {
|
||||
attn_rot_hadamard[n] = std::vector<float>(n*n);
|
||||
|
||||
ggml_init_params params = {
|
||||
|
|
@ -1308,7 +1320,7 @@ ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const {
|
|||
// ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088
|
||||
do {
|
||||
nrot *= 2;
|
||||
} while (hparams.n_embd_head_k() % nrot == 0);
|
||||
} while (n_embd_head_k_all % nrot == 0);
|
||||
nrot /= 2;
|
||||
|
||||
res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot);
|
||||
|
|
|
|||
|
|
@ -239,6 +239,11 @@ private:
|
|||
bool attn_rot_k = false;
|
||||
bool attn_rot_v = false;
|
||||
|
||||
// if all layers participating in the cache have constant head size, the value is stored here
|
||||
// otherwise the value is -1
|
||||
int32_t n_embd_head_k_all = 0;
|
||||
int32_t n_embd_head_v_all = 0;
|
||||
|
||||
// pre-computed hadamard martrices
|
||||
std::unordered_map<int64_t, std::vector<float>> attn_rot_hadamard;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue