refactor: Remove layer index from n_embd_k/v_s
Now that n_embd_k/v_s is no longer used at all in the unified cache, we don't need the layer index to zero it out for attention layers.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
parent 1dd12133cd
commit b42c8b43cf
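Before the hunks, a minimal sketch of what the change means for callers (hypothetical types and example values, not the actual llama.cpp code): the accessors no longer zero themselves out for attention layers, so any code that still needs per-layer gating applies `recurrent_layer(il)` at the call site.

```cpp
#include <cstdint>

// Hypothetical stand-in for llama_hparams, with only the fields used here.
struct hparams_sketch {
    uint32_t ssm_d_conv  = 4;    // Mamba conv kernel size (example value)
    uint32_t ssm_d_inner = 1536; // Mamba inner dimension (example value)

    // After this commit: no layer index; returns the size unconditionally.
    uint32_t n_embd_k_s() const {
        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
    }

    // Hypothetical layer schedule for a hybrid model.
    bool recurrent_layer(uint32_t il) const { return il % 2 == 0; }
};

// A cache that still wants the old "zero for attention layers" behavior
// gates at the call site instead of inside the accessor.
uint32_t k_state_size(const hparams_sketch & hp, uint32_t il) {
    return hp.recurrent_layer(il) ? hp.n_embd_k_s() : 0;
}
```

This keeps the size computation in one place while leaving the layer schedule to the caller.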
@@ -65,10 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
-uint32_t llama_hparams::n_embd_k_s(uint32_t il) const {
-    if (!recurrent_layer(il)) {
-        return 0;
-    }
+uint32_t llama_hparams::n_embd_k_s() const {
     if (wkv_head_size != 0) {
         // for RWKV models
         return token_shift_count * n_embd;
@@ -79,10 +76,7 @@ uint32_t llama_hparams::n_embd_k_s(uint32_t il) const {
     return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
 }
 
-uint32_t llama_hparams::n_embd_v_s(uint32_t il) const {
-    if (!recurrent_layer(il)) {
-        return 0;
-    }
+uint32_t llama_hparams::n_embd_v_s() const {
     if (wkv_head_size != 0) {
         // corresponds to RWKV's wkv_states size
         return n_embd * wkv_head_size;
@@ -184,10 +184,10 @@ struct llama_hparams {
 
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-    uint32_t n_embd_k_s(uint32_t il = 0) const;
+    uint32_t n_embd_k_s() const;
 
     // dimension of the recurrent state embeddings
-    uint32_t n_embd_v_s(uint32_t il = 0) const;
+    uint32_t n_embd_v_s() const;
 
     // whether or not the given layer is recurrent (for hybrid models)
     bool recurrent_layer(uint32_t il) const;
@@ -69,8 +69,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i);
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i);
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
         const char * dev_name = "CPU";
 
@@ -754,7 +754,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
     // Iterate and write all the keys first, each row is a cell
     // Get whole range at a time
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
         // Write key type
         const int32_t k_type_i = (int32_t)k_l[il]->type;
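As an aside, a hedged sketch of the row layout this serialization loop assumes (stand-in types, not the real llama.cpp I/O interface): each cell is one contiguous row of `n_embd_k_gqa(il) + n_embd_k_s()` key elements, so after this commit only the GQA term varies per layer.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Stand-in for one layer's K buffer in the recurrent cache: `size` cells,
// each cell stored as one contiguous row of (n_embd_k_gqa + n_embd_k_s)
// float elements.
struct k_layer_sketch {
    uint32_t n_embd_k_gqa; // attention width, still layer-dependent
    uint32_t n_embd_k_s;   // recurrent state width, layer-independent now
    uint32_t size;         // number of cells
    std::vector<float> data;

    uint32_t row_size() const { return n_embd_k_gqa + n_embd_k_s; }

    // Byte (offset, length) of the row belonging to one cell.
    std::pair<size_t, size_t> cell_range(uint32_t cell) const {
        const size_t row_bytes = row_size() * sizeof(float);
        return { cell * row_bytes, row_bytes };
    }
};
```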
@@ -774,7 +774,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
 
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -795,7 +795,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std
         // When v is transposed, we also need the element size and get the element ranges from each row
         const uint32_t kv_size = size;
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
             // Write value type
             const int32_t v_type_i = (int32_t)v_l[il]->type;
@@ -942,7 +942,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
     for (uint32_t il = 0; il < n_layer; ++il) {
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il);
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
 
         // Read type of key
         int32_t k_type_i_ref;
@@ -970,7 +970,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
 
     if (!v_trans) {
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
             // Read type of value
             int32_t v_type_i_ref;
@@ -998,7 +998,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
     } else {
         // For each layer, read the values for each cell (transposed)
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il);
+            const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
 
             // Read type of value
             int32_t v_type_i_ref;
@@ -9194,11 +9194,11 @@ struct llm_build_mamba : public llm_graph_context {
         // (ab)using the KV cache to store the states
         ggml_tensor * conv = build_recurrent_state(
                 gf, conv_states_all, state_copy,
-                hparams.n_embd_k_s(il), n_seqs);
+                hparams.n_embd_k_s(), n_seqs);
         conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
         ggml_tensor * ssm = build_recurrent_state(
                 gf, ssm_states_all, state_copy,
-                hparams.n_embd_v_s(il), n_seqs);
+                hparams.n_embd_v_s(), n_seqs);
         ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);
 
         // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
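These last two call sites can drop `il` because a pure Mamba model is recurrent at every layer, so the state sizes are uniform. A short check of the arithmetic the reshapes rely on, with made-up dimensions:

```cpp
#include <cassert>
#include <cstdint>

int main() {
    // Example Mamba dimensions (hypothetical, not from a real config).
    const uint32_t d_conv  = 4;
    const uint32_t d_inner = 1536;
    const uint32_t d_state = 16;

    // n_embd_k_s(): rolling conv state, one column narrower than the kernel.
    const uint32_t n_embd_k_s = (d_conv > 0 ? d_conv - 1 : 0) * d_inner;
    // n_embd_v_s(): SSM recurrent state.
    const uint32_t n_embd_v_s = d_state * d_inner;

    // The sizes passed to build_recurrent_state must factor exactly into the
    // ggml_reshape_3d shapes above: {d_conv - 1, d_inner, n_seqs} and
    // {d_state, d_inner, n_seqs}.
    assert(n_embd_k_s == 3 * 1536);  // 4608
    assert(n_embd_v_s == 16 * 1536); // 24576
    return 0;
}
```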