feat: Overhaul build_recurrent_state / build_inp_s_copy to match attention pattern

https://github.com/ggml-org/llama.cpp/pull/13979/files#r2141701738

This is a big overhaul to bring consistency between how inputs and per-
layer components are created for attention layers and recurrent layers. The
main changes are:

- Rename class llm_graph_input_s_copy -> llm_graph_input_rs
- Add a corresponding llm_graph_input_rs_hybrid_recurrent
- Rename build_inp_s_copy -> build_rs_inp_recurrent
- Add a corresponding build_rs_inp_hybrid_recurrent
- Rename build_recurrent_state -> build_rs to match build_attn w/
llm_graph_input_rs android-build AUTHORS bamba-9b-2.2T.gguf bamba-9b-2.2T.q4_k_m.gguf broken.log build build-rel build-xcframework.sh build.android build.android.bak ci cmake CMakeLists.txt CMakePresets.json CODEOWNERS common common.o CONTRIBUTING.md convert_hf_to_gguf_update.py convert_hf_to_gguf.py convert_llama_ggml_to_gguf.py convert_lora_to_gguf.py debug.log docs examples flake.lock flake.nix ggml ggml-alloc.o ggml-backend.o ggml-metal.o ggml-model-BF16.gguf ggml-model-Q4_K_M.gguf ggml-quants.o ggml.o gguf-py grammar-parser.o grammars include LICENSE licenses llama.log llama.o llamacpp_trace.log main.log Makefile media models mypy.ini pocs poetry.lock prompts pyproject.toml pyrightconfig.json q4_k_m_boot.log q8_0_boot.log quant.log quant2.log README.md requirements requirements.txt sampling.o scripts SECURITY.md src test-grammar-output.tmp test-json-schema-input.tmp tests tools vendor working.log as the first input
- Add a corresponding overload of build_rs w/
llm_graph_input_rs_hybrid_recurrent android-build AUTHORS bamba-9b-2.2T.gguf bamba-9b-2.2T.q4_k_m.gguf broken.log build build-rel build-xcframework.sh build.android build.android.bak ci cmake CMakeLists.txt CMakePresets.json CODEOWNERS common common.o CONTRIBUTING.md convert_hf_to_gguf_update.py convert_hf_to_gguf.py convert_llama_ggml_to_gguf.py convert_lora_to_gguf.py debug.log docs examples flake.lock flake.nix ggml ggml-alloc.o ggml-backend.o ggml-metal.o ggml-model-BF16.gguf ggml-model-Q4_K_M.gguf ggml-quants.o ggml.o gguf-py grammar-parser.o grammars include LICENSE licenses llama.log llama.o llamacpp_trace.log main.log Makefile media models mypy.ini pocs poetry.lock prompts pyproject.toml pyrightconfig.json q4_k_m_boot.log q8_0_boot.log quant.log quant2.log README.md requirements requirements.txt sampling.o scripts SECURITY.md src test-grammar-output.tmp test-json-schema-input.tmp tests tools vendor working.log as the first input
- Add a llm_graph_input_attn_kv_hybrid_recurrent analogous to
llm_graph_input_attn_kv_unified
- Add a build_attn override that takes
llm_graph_input_attn_kv_hybrid_recurrent android-build AUTHORS bamba-9b-2.2T.gguf bamba-9b-2.2T.q4_k_m.gguf broken.log build build-rel build-xcframework.sh build.android build.android.bak ci cmake CMakeLists.txt CMakePresets.json CODEOWNERS common common.o CONTRIBUTING.md convert_hf_to_gguf_update.py convert_hf_to_gguf.py convert_llama_ggml_to_gguf.py convert_lora_to_gguf.py debug.log docs examples flake.lock flake.nix ggml ggml-alloc.o ggml-backend.o ggml-metal.o ggml-model-BF16.gguf ggml-model-Q4_K_M.gguf ggml-quants.o ggml.o gguf-py grammar-parser.o grammars include LICENSE licenses llama.log llama.o llamacpp_trace.log main.log Makefile media models mypy.ini pocs poetry.lock prompts pyproject.toml pyrightconfig.json q4_k_m_boot.log q8_0_boot.log quant.log quant2.log README.md requirements requirements.txt sampling.o scripts SECURITY.md src test-grammar-output.tmp test-json-schema-input.tmp tests tools vendor working.log as the first input

This makes the two paradigms fully consistent. The main drawback is the
code duplication in the build_attn and build_rs implementations where the
only difference between implementations is how they cast the memory state.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
Gabe Goodhart 2025-06-12 17:04:27 -06:00
parent 4ec4e6a801
commit 11cd80d5de
3 changed files with 241 additions and 101 deletions

View File

@ -239,7 +239,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
}
}
void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);
const int64_t n_kv = kv_state->get_n_kv();
@ -255,6 +255,11 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
}
}
llm_graph_input_rs_hybrid_recurrent::llm_graph_input_rs_hybrid_recurrent(
const llama_kv_cache_hybrid_recurrent_state * kv_state) :
llm_graph_input_rs(kv_state->get_state_recurrent()) {
}
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch);
@ -360,6 +365,13 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
}
}
llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_hybrid_recurrent_state * kv_state) :
llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
}
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
if (self_kq_mask) {
kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
@ -962,25 +974,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
return cur;
}
ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const {
if (kv_state == nullptr) {
kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
}
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
const auto n_kv = kv_state->get_n_kv();
auto & cur = inp->s_copy;
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
ggml_set_input(cur);
res->add_input(std::move(inp));
return cur;
}
ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
@ -1262,9 +1255,7 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);
// NOTE: For hybrid caches, this may be a child of mstate, so we use the one
// encapsulated in inp
const auto * kv_state = inp->kv_state;
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
// store to KV cache
{
@ -1296,15 +1287,14 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state->get_state_attn());
llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(
hparams, cparams, static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
const auto n_kv = kv_state->get_state_attn()->get_n_kv();
const auto n_kv = inp->kv_state->get_n_kv();
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
//cb(inp->self_kq_mask, "KQ_mask", -1);
@ -1313,7 +1303,57 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_re
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}
return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_hybrid_recurrent * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * v_mla,
float kq_scale,
int il) const {
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, q_cur);
ggml_build_forward_expand(gf, k_cur);
ggml_build_forward_expand(gf, v_cur);
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_attn();
// store to KV cache
{
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
}
const auto & kq_mask = inp->get_kq_mask();
ggml_tensor * q = q_cur;
ggml_tensor * k = kv_state->get_k(ctx0, il);
ggml_tensor * v = kv_state->get_v(ctx0, il);
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
cb(cur, "kqv_out", il);
if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4) {
// GLM4 seems to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}
if (wo_b) {
cur = ggml_add(ctx0, cur, wo_b);
}
return cur;
}
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
@ -1455,18 +1495,30 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}
ggml_tensor * llm_graph_context::build_recurrent_state(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies,
const llama_kv_cache_recurrent_state * kv_state) const {
llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
if (kv_state == nullptr) {
kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
}
auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
const auto n_kv = kv_state->get_n_kv();
auto & cur = inp->s_copy;
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
ggml_set_input(cur);
return (llm_graph_input_rs *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_rs(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
const auto n_kv = kv_state->get_n_kv();
const auto kv_head = kv_state->get_head();
@ -1485,7 +1537,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
// copy states
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
// {state_size, kv_size} -> {state_size, n_seqs}
output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
ggml_build_forward_expand(gf, output_states);
} else {
// FIXME: make the gathering operation happen before the copy below
@ -1494,7 +1546,66 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
}
// copy extra states which won't be changed further (between n_seqs and n_kv)
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
ggml_build_forward_expand(gf,
ggml_cpy(ctx0,
states_extra,
ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
return output_states;
}
llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const {
auto inp = std::make_unique<llm_graph_input_rs_hybrid_recurrent>(
static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
const auto n_kv = inp->kv_state->get_n_kv();
auto & cur = inp->s_copy;
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
ggml_set_input(cur);
return (llm_graph_input_rs_hybrid_recurrent *) res->add_input(std::move(inp));
}
ggml_tensor * llm_graph_context::build_rs(
llm_graph_input_rs_hybrid_recurrent * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies) const {
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent();
const auto n_kv = kv_state->get_n_kv();
const auto kv_head = kv_state->get_head();
const auto rs_zero = kv_state->get_rs_z();
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
// Clear a single state which will then be copied to the other cleared states.
// Note that this is a no-op when the view is zero-sized.
ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
ggml_tensor * output_states;
if (!avoid_copies) {
// copy states
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
// {state_size, kv_size} -> {state_size, n_seqs}
output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
ggml_build_forward_expand(gf, output_states);
} else {
// FIXME: make the gathering operation happen before the copy below
// (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
output_states = states;
}
// copy extra states which won't be changed further (between n_seqs and n_kv)
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
ggml_build_forward_expand(gf,
ggml_cpy(ctx0,
states_extra,
@ -1504,9 +1615,9 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
}
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
ggml_cgraph * gf,
ggml_tensor * state_copy,
const llama_ubatch & ubatch,
llm_graph_input_rs * inp,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@ -1516,8 +1627,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
ggml_tensor * token_shift_all = kv_state->get_k_l(il);
ggml_tensor * token_shift = build_recurrent_state(
gf, token_shift_all, state_copy,
ggml_tensor * token_shift = build_rs(
inp, gf, token_shift_all,
hparams.n_embd_k_s(), n_seqs);
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);

View File

@ -189,10 +189,10 @@ public:
const llama_cparams & cparams;
};
class llm_graph_input_s_copy : public llm_graph_input_i {
class llm_graph_input_rs : public llm_graph_input_i {
public:
llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
virtual ~llm_graph_input_s_copy() = default;
llm_graph_input_rs(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
virtual ~llm_graph_input_rs() = default;
void set_input(const llama_ubatch * ubatch) override;
@ -201,6 +201,12 @@ public:
const llama_kv_cache_recurrent_state * kv_state;
};
class llm_graph_input_rs_hybrid_recurrent : public llm_graph_input_rs {
public:
llm_graph_input_rs_hybrid_recurrent(const llama_kv_cache_hybrid_recurrent_state * kv_state);
virtual ~llm_graph_input_rs_hybrid_recurrent() = default;
};
class llm_graph_input_cross_embd : public llm_graph_input_i {
public:
llm_graph_input_cross_embd(
@ -258,6 +264,15 @@ public:
const llama_kv_cache_unified_state * kv_state;
};
class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified {
public:
llm_graph_input_attn_kv_hybrid_recurrent(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_hybrid_recurrent_state * kv_state);
virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default;
};
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
public:
llm_graph_input_attn_kv_unified_iswa(
@ -509,7 +524,6 @@ struct llm_graph_context {
ggml_tensor * build_inp_out_ids() const;
ggml_tensor * build_inp_mean() const;
ggml_tensor * build_inp_cls() const;
ggml_tensor * build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state = nullptr) const;
ggml_tensor * build_inp_cross_embd() const;
ggml_tensor * build_inp_pos_bucket_enc() const;
@ -575,8 +589,6 @@ struct llm_graph_context {
float kq_scale,
int il) const;
llm_graph_input_attn_kv_unified * build_attn_inp_kv_hybrid_recurrent() const;
llm_graph_input_attn_cross * build_attn_inp_cross() const;
ggml_tensor * build_attn(
@ -592,23 +604,48 @@ struct llm_graph_context {
float kq_scale,
int il) const;
llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const;
ggml_tensor * build_attn(
llm_graph_input_attn_kv_hybrid_recurrent * inp,
ggml_cgraph * gf,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
//
// recurrent
//
ggml_tensor * build_recurrent_state(
ggml_cgraph * gf,
ggml_tensor * s,
ggml_tensor * state_copy,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies = false,
const llama_kv_cache_recurrent_state * kv_state = nullptr) const;
llm_graph_input_rs * build_rs_inp_recurrent() const;
ggml_tensor * build_rs(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies = false) const;
llm_graph_input_rs_hybrid_recurrent * build_rs_inp_hybrid_recurrent() const;
ggml_tensor * build_rs(
llm_graph_input_rs_hybrid_recurrent * inp,
ggml_cgraph * gf,
ggml_tensor * s,
int32_t state_size,
int32_t n_seqs,
bool avoid_copies = false) const;
ggml_tensor * build_rwkv_token_shift_load(
ggml_cgraph * gf,
ggml_tensor * state_copy,
const llama_ubatch & ubatch,
llm_graph_input_rs * inp,
ggml_cgraph * gf,
const llama_ubatch & ubatch,
int il) const;
ggml_tensor * build_rwkv_token_shift_store(

View File

@ -9116,7 +9116,7 @@ struct llm_build_mamba : public llm_graph_context {
// {n_embd, n_tokens}
inpL = build_inp_embd(model.tok_embd);
ggml_tensor * state_copy = build_inp_s_copy();
auto * rs_inp = build_rs_inp_recurrent();
for (int il = 0; il < n_layer; ++il) {
// norm
@ -9125,7 +9125,7 @@ struct llm_build_mamba : public llm_graph_context {
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il);
if (il == n_layer - 1) {
// skip computing output for unused tokens
@ -9163,11 +9163,11 @@ struct llm_build_mamba : public llm_graph_context {
// TODO: split
ggml_tensor * build_mamba_layer(
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * state_copy,
const llama_ubatch & ubatch,
int il) const {
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * cur,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
const auto kv_head = kv_state->get_head();
@ -9192,12 +9192,12 @@ struct llm_build_mamba : public llm_graph_context {
ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
// (ab)using the KV cache to store the states
ggml_tensor * conv = build_recurrent_state(
gf, conv_states_all, state_copy,
ggml_tensor * conv = build_rs(
inp, gf, conv_states_all,
hparams.n_embd_k_s(), n_seqs);
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
ggml_tensor * ssm = build_recurrent_state(
gf, ssm_states_all, state_copy,
ggml_tensor * ssm = build_rs(
inp, gf, ssm_states_all,
hparams.n_embd_v_s(), n_seqs);
ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
@ -11909,10 +11909,10 @@ struct llm_build_rwkv6_base : public llm_graph_context {
}
ggml_tensor * build_rwkv6_time_mix(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * x_prev,
ggml_tensor * state_copy,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
@ -12036,8 +12036,8 @@ struct llm_build_rwkv6_base : public llm_graph_context {
k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
}
ggml_tensor * wkv_state = build_recurrent_state(
gf, kv_state->get_v_l(il), state_copy,
ggml_tensor * wkv_state = build_rs(
inp, gf, kv_state->get_v_l(il),
hparams.n_embd_v_s(), n_seqs);
ggml_tensor * wkv_output;
@ -12092,7 +12092,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
inpL = build_inp_embd(model.tok_embd);
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
ggml_tensor * state_copy = build_inp_s_copy();
auto * rs_inp = build_rs_inp_recurrent();
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
@ -12102,9 +12102,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
const llama_layer * layer = &model.layers[il];
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
ggml_tensor * token_shift = build_rwkv_token_shift_load(
gf, state_copy, ubatch, il
);
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
@ -12119,7 +12117,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
1
);
cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@ -12189,7 +12187,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
inpL = build_inp_embd(model.tok_embd);
ggml_tensor * state_copy = build_inp_s_copy();
auto * rs_inp = build_rs_inp_recurrent();
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
@ -12199,9 +12197,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
const llama_layer * layer = &model.layers[il];
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
ggml_tensor * token_shift = build_rwkv_token_shift_load(
gf, state_copy, ubatch, il
);
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
cb(att_norm, "attn_norm", il);
@ -12213,7 +12209,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
1
);
cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
@ -12301,10 +12297,10 @@ struct llm_build_rwkv7_base : public llm_graph_context {
}
ggml_tensor * build_rwkv7_time_mix(
llm_graph_input_rs * inp,
ggml_cgraph * gf,
ggml_tensor * cur,
ggml_tensor * x_prev,
ggml_tensor * state_copy,
ggml_tensor *& first_layer_value,
const llama_ubatch & ubatch,
int il) const {
@ -12387,8 +12383,8 @@ struct llm_build_rwkv7_base : public llm_graph_context {
v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
ggml_tensor * wkv_state = build_recurrent_state(
gf, kv_state->get_v_l(il), state_copy,
ggml_tensor * wkv_state = build_rs(
inp, gf, kv_state->get_v_l(il),
hparams.n_embd_v_s(), n_seqs);
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
@ -12445,7 +12441,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
inpL = build_inp_embd(model.tok_embd);
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
ggml_tensor * state_copy = build_inp_s_copy();
auto * rs_inp = build_rs_inp_recurrent();
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
@ -12455,9 +12451,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
const llama_layer * layer = &model.layers[il];
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
ggml_tensor * token_shift = build_rwkv_token_shift_load(
gf, state_copy, ubatch, il
);
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
@ -12472,7 +12466,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
1
);
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@ -12538,7 +12532,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
inpL = build_inp_embd(model.tok_embd);
ggml_tensor * state_copy = build_inp_s_copy();
auto * rs_inp = build_rs_inp_recurrent();
const auto n_embd = hparams.n_embd;
const auto n_seq_tokens = ubatch.n_seq_tokens;
@ -12548,9 +12542,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
const llama_layer * layer = &model.layers[il];
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
ggml_tensor * token_shift = build_rwkv_token_shift_load(
gf, state_copy, ubatch, il
);
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
cb(att_norm, "attn_norm", il);
@ -12562,7 +12554,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
1
);
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));