feat: Overhaul build_recurrent_state / build_inp_s_copy to match attention pattern
https://github.com/ggml-org/llama.cpp/pull/13979/files#r2141701738

This is a big overhaul to bring consistency between how inputs and per-layer components are created for attention layers and recurrent layers. The main changes are:

- Rename class llm_graph_input_s_copy -> llm_graph_input_rs
- Add a corresponding llm_graph_input_rs_hybrid_recurrent
- Rename build_inp_s_copy -> build_rs_inp_recurrent
- Add a corresponding build_rs_inp_hybrid_recurrent
- Rename build_recurrent_state -> build_rs to match build_attn w/ llm_graph_input_rs * as the first input
- Add a corresponding overload of build_rs w/ llm_graph_input_rs_hybrid_recurrent * as the first input
- Add a llm_graph_input_attn_kv_hybrid_recurrent analogous to llm_graph_input_attn_kv_unified
- Add a build_attn override that takes llm_graph_input_attn_kv_hybrid_recurrent * as the first input

This makes the two paradigms fully consistent. The main drawback is the code duplication in the build_attn and build_rs implementations where the only difference between implementations is how they cast the memory state.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
parent
4ec4e6a801
commit
11cd80d5de
|
|
@ -239,7 +239,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|||
}
|
||||
}
|
||||
|
||||
void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
||||
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
|
||||
GGML_UNUSED(ubatch);
|
||||
|
||||
const int64_t n_kv = kv_state->get_n_kv();
|
||||
|
|
@ -255,6 +255,11 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
|
|||
}
|
||||
}
|
||||
|
||||
// Constructs the recurrent-state graph input for a hybrid cache.
// The hybrid state is not stored directly: its recurrent sub-state
// (kv_state->get_state_recurrent()) is handed to the llm_graph_input_rs base.
// kv_state is dereferenced unconditionally, so callers must pass a non-null pointer.
llm_graph_input_rs_hybrid_recurrent::llm_graph_input_rs_hybrid_recurrent(
|
||||
const llama_kv_cache_hybrid_recurrent_state * kv_state) :
|
||||
llm_graph_input_rs(kv_state->get_state_recurrent()) {
|
||||
}
|
||||
|
||||
void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
|
||||
GGML_UNUSED(ubatch);
|
||||
|
||||
|
|
@ -360,6 +365,13 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
|||
}
|
||||
}
|
||||
|
||||
// Constructs the attention graph input for a hybrid cache.
// hparams and cparams are forwarded unchanged to the llm_graph_input_attn_kv_unified
// base; the base receives the attention sub-state (kv_state->get_state_attn())
// rather than the hybrid state itself.
// kv_state is dereferenced unconditionally, so callers must pass a non-null pointer.
llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_hybrid_recurrent_state * kv_state) :
|
||||
llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
|
||||
}
|
||||
|
||||
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
|
||||
if (self_kq_mask) {
|
||||
kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
||||
|
|
@ -962,25 +974,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
|
|||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const {
|
||||
if (kv_state == nullptr) {
|
||||
kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
}
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
|
||||
|
||||
const auto n_kv = kv_state->get_n_kv();
|
||||
|
||||
auto & cur = inp->s_copy;
|
||||
|
||||
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
|
||||
ggml_set_input(cur);
|
||||
|
||||
res->add_input(std::move(inp));
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
|
||||
auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
|
||||
|
||||
|
|
@ -1262,9 +1255,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
ggml_build_forward_expand(gf, k_cur);
|
||||
ggml_build_forward_expand(gf, v_cur);
|
||||
|
||||
// NOTE: For hybrid caches, this may be a child of mstate, so we use the one
|
||||
// encapsulated in inp
|
||||
const auto * kv_state = inp->kv_state;
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
|
||||
|
||||
// store to KV cache
|
||||
{
|
||||
|
|
@ -1296,15 +1287,14 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
return cur;
|
||||
}
|
||||
|
||||
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
|
||||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state->get_state_attn());
|
||||
llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(
|
||||
hparams, cparams, static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
|
||||
|
||||
{
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
|
||||
|
||||
const auto n_kv = kv_state->get_state_attn()->get_n_kv();
|
||||
const auto n_kv = inp->kv_state->get_n_kv();
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
||||
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
||||
|
|
@ -1313,7 +1303,57 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_re
|
|||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
}
|
||||
|
||||
return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
|
||||
return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
// build_attn overload for the hybrid recurrent cache input.
//
// Steps, in order:
//   1. expand q_cur / k_cur / v_cur into the graph `gf`
//   2. copy k_cur / v_cur into the attention sub-cache for layer `il`
//   3. run build_attn_mha on q_cur and the cached K/V, using the mask from `inp`
//   4. if `wo` is non-null, apply it via build_lora_mm (F32 precision for GLM4)
//   5. if `wo_b` is non-null, add it to the result
//
// `inp` is only used to obtain the KQ mask; the cache state is taken from mstate.
// Returns the tensor produced by the last step that ran.
ggml_tensor * llm_graph_context::build_attn(
|
||||
llm_graph_input_attn_kv_hybrid_recurrent * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur,
|
||||
ggml_tensor * k_cur,
|
||||
ggml_tensor * v_cur,
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * v_mla,
|
||||
float kq_scale,
|
||||
int il) const {
|
||||
// these nodes are added to the graph together so that they are not reordered
|
||||
// by doing so, the number of splits in the graph is reduced
|
||||
ggml_build_forward_expand(gf, q_cur);
|
||||
ggml_build_forward_expand(gf, k_cur);
|
||||
ggml_build_forward_expand(gf, v_cur);
|
||||
|
||||
// mstate is cast (unchecked) to the hybrid state and its attention sub-state is used
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_attn();
|
||||
|
||||
// store to KV cache
|
||||
{
|
||||
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
|
||||
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
|
||||
}
|
||||
|
||||
const auto & kq_mask = inp->get_kq_mask();
|
||||
|
||||
ggml_tensor * q = q_cur;
|
||||
ggml_tensor * k = kv_state->get_k(ctx0, il);
|
||||
ggml_tensor * v = kv_state->get_v(ctx0, il);
|
||||
|
||||
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
||||
if (wo) {
|
||||
cur = build_lora_mm(wo, cur);
|
||||
if (arch == LLM_ARCH_GLM4) {
|
||||
// GLM4 seems to have numerical issues with half-precision accumulators
|
||||
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
||||
}
|
||||
}
|
||||
|
||||
if (wo_b) {
|
||||
cur = ggml_add(ctx0, cur, wo_b);
|
||||
}
|
||||
|
||||
return cur;
|
||||
}
|
||||
|
||||
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
|
||||
|
|
@ -1455,18 +1495,30 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_recurrent_state(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
ggml_tensor * state_copy,
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
bool avoid_copies,
|
||||
const llama_kv_cache_recurrent_state * kv_state) const {
|
||||
llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
|
||||
if (kv_state == nullptr) {
|
||||
kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
}
|
||||
auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
|
||||
|
||||
const auto n_kv = kv_state->get_n_kv();
|
||||
|
||||
auto & cur = inp->s_copy;
|
||||
|
||||
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
|
||||
ggml_set_input(cur);
|
||||
|
||||
return (llm_graph_input_rs *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_rs(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
bool avoid_copies) const {
|
||||
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
|
||||
const auto n_kv = kv_state->get_n_kv();
|
||||
const auto kv_head = kv_state->get_head();
|
||||
|
|
@ -1485,7 +1537,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
|
|||
// copy states
|
||||
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
|
||||
// {state_size, kv_size} -> {state_size, n_seqs}
|
||||
output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
|
||||
output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
|
||||
ggml_build_forward_expand(gf, output_states);
|
||||
} else {
|
||||
// FIXME: make the gathering operation happen before the copy below
|
||||
|
|
@ -1494,7 +1546,66 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
|
|||
}
|
||||
|
||||
// copy extra states which won't be changed further (between n_seqs and n_kv)
|
||||
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
|
||||
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
|
||||
ggml_build_forward_expand(gf,
|
||||
ggml_cpy(ctx0,
|
||||
states_extra,
|
||||
ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s))));
|
||||
|
||||
return output_states;
|
||||
}
|
||||
|
||||
// Creates and registers the recurrent-state graph input for a hybrid cache.
//
// mstate is cast (unchecked) to the hybrid recurrent state and wrapped in a
// llm_graph_input_rs_hybrid_recurrent. The input's `s_copy` tensor is allocated as
// a 1-D I32 tensor with n_kv elements (n_kv read from the input's own kv_state) and
// flagged as a graph input. Ownership of the input object moves to `res`.
// Returns the pointer from res->add_input(), cast to the concrete input type.
llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const {
|
||||
auto inp = std::make_unique<llm_graph_input_rs_hybrid_recurrent>(
|
||||
static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate));
|
||||
|
||||
const auto n_kv = inp->kv_state->get_n_kv();
|
||||
|
||||
auto & cur = inp->s_copy;
|
||||
|
||||
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
|
||||
ggml_set_input(cur);
|
||||
|
||||
return (llm_graph_input_rs_hybrid_recurrent *) res->add_input(std::move(inp));
|
||||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_rs(
|
||||
llm_graph_input_rs_hybrid_recurrent * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
bool avoid_copies) const {
|
||||
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate)->get_state_recurrent();
|
||||
|
||||
const auto n_kv = kv_state->get_n_kv();
|
||||
const auto kv_head = kv_state->get_head();
|
||||
const auto rs_zero = kv_state->get_rs_z();
|
||||
|
||||
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
|
||||
|
||||
// Clear a single state which will then be copied to the other cleared states.
|
||||
// Note that this is a no-op when the view is zero-sized.
|
||||
ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
|
||||
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
|
||||
|
||||
ggml_tensor * output_states;
|
||||
|
||||
if (!avoid_copies) {
|
||||
// copy states
|
||||
// NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
|
||||
// {state_size, kv_size} -> {state_size, n_seqs}
|
||||
output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0));
|
||||
ggml_build_forward_expand(gf, output_states);
|
||||
} else {
|
||||
// FIXME: make the gathering operation happen before the copy below
|
||||
// (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
|
||||
output_states = states;
|
||||
}
|
||||
|
||||
// copy extra states which won't be changed further (between n_seqs and n_kv)
|
||||
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0]));
|
||||
ggml_build_forward_expand(gf,
|
||||
ggml_cpy(ctx0,
|
||||
states_extra,
|
||||
|
|
@ -1504,9 +1615,9 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
|
|||
}
|
||||
|
||||
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * state_copy,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
|
||||
|
|
@ -1516,8 +1627,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
|||
|
||||
ggml_tensor * token_shift_all = kv_state->get_k_l(il);
|
||||
|
||||
ggml_tensor * token_shift = build_recurrent_state(
|
||||
gf, token_shift_all, state_copy,
|
||||
ggml_tensor * token_shift = build_rs(
|
||||
inp, gf, token_shift_all,
|
||||
hparams.n_embd_k_s(), n_seqs);
|
||||
|
||||
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
|
||||
|
|
|
|||
|
|
@ -189,10 +189,10 @@ public:
|
|||
const llama_cparams & cparams;
|
||||
};
|
||||
|
||||
class llm_graph_input_s_copy : public llm_graph_input_i {
|
||||
class llm_graph_input_rs : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
|
||||
virtual ~llm_graph_input_s_copy() = default;
|
||||
llm_graph_input_rs(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
|
||||
virtual ~llm_graph_input_rs() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
|
||||
|
|
@ -201,6 +201,12 @@ public:
|
|||
const llama_kv_cache_recurrent_state * kv_state;
|
||||
};
|
||||
|
||||
// Recurrent-state graph input for hybrid caches.
// Derives from llm_graph_input_rs and declares only a constructor that accepts the
// hybrid cache state plus a defaulted virtual destructor; no members or overrides
// are added here.
class llm_graph_input_rs_hybrid_recurrent : public llm_graph_input_rs {
|
||||
public:
|
||||
// kv_state: hybrid cache state this input is built from (defined out of line)
llm_graph_input_rs_hybrid_recurrent(const llama_kv_cache_hybrid_recurrent_state * kv_state);
|
||||
virtual ~llm_graph_input_rs_hybrid_recurrent() = default;
|
||||
};
|
||||
|
||||
class llm_graph_input_cross_embd : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_cross_embd(
|
||||
|
|
@ -258,6 +264,15 @@ public:
|
|||
const llama_kv_cache_unified_state * kv_state;
|
||||
};
|
||||
|
||||
// Attention graph input for hybrid caches.
// Derives from llm_graph_input_attn_kv_unified and declares only a constructor that
// accepts the hybrid cache state plus a defaulted virtual destructor; no members or
// overrides are added here.
class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified {
|
||||
public:
|
||||
// hparams / cparams: model and context parameters (taken by const reference)
// kv_state: hybrid cache state this input is built from (defined out of line)
llm_graph_input_attn_kv_hybrid_recurrent(
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache_hybrid_recurrent_state * kv_state);
|
||||
virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default;
|
||||
};
|
||||
|
||||
class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_attn_kv_unified_iswa(
|
||||
|
|
@ -509,7 +524,6 @@ struct llm_graph_context {
|
|||
ggml_tensor * build_inp_out_ids() const;
|
||||
ggml_tensor * build_inp_mean() const;
|
||||
ggml_tensor * build_inp_cls() const;
|
||||
ggml_tensor * build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state = nullptr) const;
|
||||
|
||||
ggml_tensor * build_inp_cross_embd() const;
|
||||
ggml_tensor * build_inp_pos_bucket_enc() const;
|
||||
|
|
@ -575,8 +589,6 @@ struct llm_graph_context {
|
|||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
llm_graph_input_attn_kv_unified * build_attn_inp_kv_hybrid_recurrent() const;
|
||||
|
||||
llm_graph_input_attn_cross * build_attn_inp_cross() const;
|
||||
|
||||
ggml_tensor * build_attn(
|
||||
|
|
@ -592,23 +604,48 @@ struct llm_graph_context {
|
|||
float kq_scale,
|
||||
int il) const;
|
||||
|
||||
llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const;
|
||||
|
||||
ggml_tensor * build_attn(
|
||||
llm_graph_input_attn_kv_hybrid_recurrent * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * wo,
|
||||
ggml_tensor * wo_b,
|
||||
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
|
||||
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
|
||||
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
|
||||
ggml_tensor * kq_b,
|
||||
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
|
||||
float kq_scale,
|
||||
int il) const;
|
||||
//
|
||||
// recurrent
|
||||
//
|
||||
|
||||
ggml_tensor * build_recurrent_state(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
ggml_tensor * state_copy,
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
bool avoid_copies = false,
|
||||
const llama_kv_cache_recurrent_state * kv_state = nullptr) const;
|
||||
llm_graph_input_rs * build_rs_inp_recurrent() const;
|
||||
|
||||
ggml_tensor * build_rs(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
bool avoid_copies = false) const;
|
||||
|
||||
llm_graph_input_rs_hybrid_recurrent * build_rs_inp_hybrid_recurrent() const;
|
||||
|
||||
ggml_tensor * build_rs(
|
||||
llm_graph_input_rs_hybrid_recurrent * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * s,
|
||||
int32_t state_size,
|
||||
int32_t n_seqs,
|
||||
bool avoid_copies = false) const;
|
||||
|
||||
ggml_tensor * build_rwkv_token_shift_load(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * state_copy,
|
||||
const llama_ubatch & ubatch,
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_cgraph * gf,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const;
|
||||
|
||||
ggml_tensor * build_rwkv_token_shift_store(
|
||||
|
|
|
|||
|
|
@ -9116,7 +9116,7 @@ struct llm_build_mamba : public llm_graph_context {
|
|||
// {n_embd, n_tokens}
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
ggml_tensor * state_copy = build_inp_s_copy();
|
||||
auto * rs_inp = build_rs_inp_recurrent();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
// norm
|
||||
|
|
@ -9125,7 +9125,7 @@ struct llm_build_mamba : public llm_graph_context {
|
|||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
cur = build_mamba_layer(gf, cur, state_copy, ubatch, il);
|
||||
cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il);
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
|
|
@ -9163,11 +9163,11 @@ struct llm_build_mamba : public llm_graph_context {
|
|||
|
||||
// TODO: split
|
||||
ggml_tensor * build_mamba_layer(
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * state_copy,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * cur,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
|
||||
const auto kv_head = kv_state->get_head();
|
||||
|
|
@ -9192,12 +9192,12 @@ struct llm_build_mamba : public llm_graph_context {
|
|||
ggml_tensor * ssm_states_all = kv_state->get_v_l(il);
|
||||
|
||||
// (ab)using the KV cache to store the states
|
||||
ggml_tensor * conv = build_recurrent_state(
|
||||
gf, conv_states_all, state_copy,
|
||||
ggml_tensor * conv = build_rs(
|
||||
inp, gf, conv_states_all,
|
||||
hparams.n_embd_k_s(), n_seqs);
|
||||
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
|
||||
ggml_tensor * ssm = build_recurrent_state(
|
||||
gf, ssm_states_all, state_copy,
|
||||
ggml_tensor * ssm = build_rs(
|
||||
inp, gf, ssm_states_all,
|
||||
hparams.n_embd_v_s(), n_seqs);
|
||||
ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
|
||||
|
||||
|
|
@ -11909,10 +11909,10 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|||
}
|
||||
|
||||
ggml_tensor * build_rwkv6_time_mix(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * x_prev,
|
||||
ggml_tensor * state_copy,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
|
||||
|
|
@ -12036,8 +12036,8 @@ struct llm_build_rwkv6_base : public llm_graph_context {
|
|||
k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
|
||||
}
|
||||
|
||||
ggml_tensor * wkv_state = build_recurrent_state(
|
||||
gf, kv_state->get_v_l(il), state_copy,
|
||||
ggml_tensor * wkv_state = build_rs(
|
||||
inp, gf, kv_state->get_v_l(il),
|
||||
hparams.n_embd_v_s(), n_seqs);
|
||||
|
||||
ggml_tensor * wkv_output;
|
||||
|
|
@ -12092,7 +12092,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
|||
inpL = build_inp_embd(model.tok_embd);
|
||||
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
|
||||
|
||||
ggml_tensor * state_copy = build_inp_s_copy();
|
||||
auto * rs_inp = build_rs_inp_recurrent();
|
||||
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
|
@ -12102,9 +12102,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
|||
const llama_layer * layer = &model.layers[il];
|
||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
||||
gf, state_copy, ubatch, il
|
||||
);
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
|
||||
|
||||
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
|
||||
ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
|
||||
|
|
@ -12119,7 +12117,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
|
|||
1
|
||||
);
|
||||
|
||||
cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
|
||||
cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
|
@ -12189,7 +12187,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
|||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
ggml_tensor * state_copy = build_inp_s_copy();
|
||||
auto * rs_inp = build_rs_inp_recurrent();
|
||||
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
|
@ -12199,9 +12197,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
|||
const llama_layer * layer = &model.layers[il];
|
||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
||||
gf, state_copy, ubatch, il
|
||||
);
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
|
||||
|
||||
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
|
||||
cb(att_norm, "attn_norm", il);
|
||||
|
|
@ -12213,7 +12209,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
|
|||
1
|
||||
);
|
||||
|
||||
cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il);
|
||||
cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
|
||||
|
||||
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
|
||||
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
|
||||
|
|
@ -12301,10 +12297,10 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|||
}
|
||||
|
||||
ggml_tensor * build_rwkv7_time_mix(
|
||||
llm_graph_input_rs * inp,
|
||||
ggml_cgraph * gf,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * x_prev,
|
||||
ggml_tensor * state_copy,
|
||||
ggml_tensor *& first_layer_value,
|
||||
const llama_ubatch & ubatch,
|
||||
int il) const {
|
||||
|
|
@ -12387,8 +12383,8 @@ struct llm_build_rwkv7_base : public llm_graph_context {
|
|||
v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
|
||||
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
|
||||
|
||||
ggml_tensor * wkv_state = build_recurrent_state(
|
||||
gf, kv_state->get_v_l(il), state_copy,
|
||||
ggml_tensor * wkv_state = build_rs(
|
||||
inp, gf, kv_state->get_v_l(il),
|
||||
hparams.n_embd_v_s(), n_seqs);
|
||||
|
||||
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
|
||||
|
|
@ -12445,7 +12441,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
|
|||
inpL = build_inp_embd(model.tok_embd);
|
||||
inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
|
||||
|
||||
ggml_tensor * state_copy = build_inp_s_copy();
|
||||
auto * rs_inp = build_rs_inp_recurrent();
|
||||
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
|
@ -12455,9 +12451,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
|
|||
const llama_layer * layer = &model.layers[il];
|
||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
||||
gf, state_copy, ubatch, il
|
||||
);
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
|
||||
|
||||
ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
|
||||
ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));
|
||||
|
|
@ -12472,7 +12466,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
|
|||
1
|
||||
);
|
||||
|
||||
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
|
||||
cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
|
@ -12538,7 +12532,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
|
|||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
ggml_tensor * state_copy = build_inp_s_copy();
|
||||
auto * rs_inp = build_rs_inp_recurrent();
|
||||
|
||||
const auto n_embd = hparams.n_embd;
|
||||
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
||||
|
|
@ -12548,9 +12542,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
|
|||
const llama_layer * layer = &model.layers[il];
|
||||
inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
|
||||
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(
|
||||
gf, state_copy, ubatch, il
|
||||
);
|
||||
ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
|
||||
|
||||
ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
|
||||
cb(att_norm, "attn_norm", il);
|
||||
|
|
@ -12562,7 +12554,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
|
|||
1
|
||||
);
|
||||
|
||||
cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il);
|
||||
cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
|
||||
|
||||
token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
|
||||
ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
|
||||
|
|
|
|||
Loading…
Reference in New Issue