diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp
index 70a7114f39..0ec3a8a501 100644
--- a/src/llama-hparams.cpp
+++ b/src/llama-hparams.cpp
@@ -65,10 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
-uint32_t llama_hparams::n_embd_k_s(uint32_t il) const {
-    if (!recurrent_layer(il)) {
-        return 0;
-    }
+uint32_t llama_hparams::n_embd_k_s() const {
     if (wkv_head_size != 0) {
         // for RWKV models
         return token_shift_count * n_embd;
@@ -79,10 +76,7 @@ uint32_t llama_hparams::n_embd_k_s(uint32_t il) const {
     return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
 }
 
-uint32_t llama_hparams::n_embd_v_s(uint32_t il) const {
-    if (!recurrent_layer(il)) {
-        return 0;
-    }
+uint32_t llama_hparams::n_embd_v_s() const {
     if (wkv_head_size != 0) {
         // corresponds to RWKV's wkv_states size
         return n_embd * wkv_head_size;
diff --git a/src/llama-hparams.h b/src/llama-hparams.h
index 3614596464..84234494c5 100644
--- a/src/llama-hparams.h
+++ b/src/llama-hparams.h
@@ -184,10 +184,10 @@ struct llama_hparams {
 
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-    uint32_t n_embd_k_s(uint32_t il = 0) const;
+    uint32_t n_embd_k_s() const;
 
     // dimension of the recurrent state embeddings
-    uint32_t n_embd_v_s(uint32_t il = 0) const;
+    uint32_t n_embd_v_s() const;
 
     // whether or not the given layer is recurrent (for hybrid models)
     bool recurrent_layer(uint32_t il) const;
diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp
index 672f0197d0..917d2a60c9 100644
--- a/src/llama-kv-cache-recurrent.cpp
+++ b/src/llama-kv-cache-recurrent.cpp
@@ -69,8 +69,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
            continue;
        }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i);
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i);
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
        const char * dev_name = "CPU";
 
@@ -754,7 +754,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
    // Iterate and write all the keys first, each row is a cell
    // Get whole range at a time
    for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
        // Write key type
        const int32_t k_type_i = (int32_t)k_l[il]->type;
@@ -774,7 +774,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
 
    if (!v_trans) {
        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
            // Write value type
            const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -795,7 +795,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
        // When v is transposed, we also need the element size and get the element ranges from each row
        const uint32_t kv_size = size;
        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
            // Write value type
            const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -942,7 +942,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
        // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
        for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
+            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
            // Read type of key
            int32_t k_type_i_ref;
@@ -970,7 +970,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
        if (!v_trans) {
            for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
                // Read type of value
                int32_t v_type_i_ref;
@@ -998,7 +998,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
        } else {
            // For each layer, read the values for each cell (transposed)
            for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
                // Read type of value
                int32_t v_type_i_ref;
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 34643226e5..6e9dd53223 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -9194,11 +9194,11 @@ struct llm_build_mamba : public llm_graph_context {
        // (ab)using the KV cache to store the states
        ggml_tensor * conv = build_recurrent_state(
                gf, conv_states_all, state_copy,
-                hparams.n_embd_k_s(il), n_seqs);
+                hparams.n_embd_k_s(), n_seqs);
        conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
        ggml_tensor * ssm = build_recurrent_state(
                gf, ssm_states_all, state_copy,
-                hparams.n_embd_v_s(il), n_seqs);
+                hparams.n_embd_v_s(), n_seqs);
        ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
 
        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}