diff --git a/src/llama.cpp b/src/llama.cpp index 510b7fe893..b076681134 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2583,7 +2583,7 @@ struct llama_hparams { return n_embd_head_v * n_head_kv; } - uint32_t n_embd_r(uint32_t il) const { // dimension of the rolling state embeddings + uint32_t n_embd_r(uint32_t il) const { // dimension of the rolling state embeddings // TODO: support using an SSM in place of the MLP of a Transformer if (n_head_kv(il) != 0) { return 0; } // corresponds to Mamba's conv_states size or RWKV's token_shift states size @@ -2597,7 +2597,7 @@ struct llama_hparams { } } - uint32_t n_embd_s(uint32_t il) const { // dimension of the recurrent state embeddings + uint32_t n_embd_s(uint32_t il) const { // dimension of the recurrent state embeddings // TODO: support using an SSM in place of the MLP of a Transformer if (n_head_kv(il) != 0) { return 0; } @@ -2875,17 +2875,13 @@ struct llama_kv_self_cache { struct llama_rs_cell { llama_pos pos = -1; - int32_t src = -1; // copy source id (cleared next when -1) + int32_t src = -1; // copy source id (cleared next when -1) std::set seq_id; - bool has_seq_id(const llama_seq_id & id) const { - return seq_id.find(id) != seq_id.end(); - } + bool has_seq_id(const llama_seq_id & id) const { return seq_id.find(id) != seq_id.end(); } - bool is_empty() const { - return seq_id.empty(); - } + bool is_empty() const { return seq_id.empty(); } }; struct llama_rs_seq_meta { @@ -2895,24 +2891,23 @@ struct llama_rs_seq_meta { // ring-buffered tree of cached recurrent state data struct llama_rs_self_cache { - - uint32_t head = 0; // first state used for the last slot + uint32_t head = 0; // first state used for the last slot uint32_t size = 0; uint32_t used = 0; // computed when finding a slot - uint32_t n = 0; // range of states used for the last slot + uint32_t n = 0; // range of states used for the last slot // with state models, a cell can hold the state for more than one past token std::vector cells; // find tail cells faster - std::vector seq_tails; // map seq_ids to cell ids + std::vector seq_tails; // map seq_ids to cell ids // per layer // NOTE: the naming of r and s is arbitrary - std::vector r_l; // rolling/shift states - std::vector s_l; // ssm (recurrent) states + std::vector r_l; // rolling/shift states + std::vector s_l; // ssm (recurrent) states // Inefficient, but thorough verification and rebuilding of the rs cache // from only the cells list with `pos` and seq_ids. @@ -2920,21 +2915,21 @@ struct llama_rs_self_cache { bool rebuild(bool debug) { bool was_valid = true; // skip for non-recurrent models - if (size == 0) { return true; } + if (size == 0) { + return true; + } // the source of truth is the cells list // buffer sizes if (size != cells.size()) { if (debug) { - LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n", - __func__, cells.size(), size); + LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n", __func__, cells.size(), size); } cells.resize(size); was_valid = false; } if (size != seq_tails.size()) { if (debug) { - LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n", - __func__, seq_tails.size(), size); + LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n", __func__, seq_tails.size(), size); } seq_tails.resize(size); was_valid = false; @@ -2994,7 +2989,7 @@ struct llama_rs_self_cache { for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { llama_rs_cell & cell = cells[cell_id]; if (cell.has_seq_id(seq_id)) { - seq_cells.push_back({cell.pos, cell_id}); + seq_cells.push_back({ cell.pos, cell_id }); } } // sort by pos and then by cell_id @@ -3718,16 +3713,16 @@ static bool llama_kv_cache_init( } if (has_kv) { - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i)*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i)*kv_size); + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i) * kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i) * kv_size); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); cache.kv.k_l.push_back(k); cache.kv.v_l.push_back(v); } if (has_rs) { - ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i)*rs_size); - ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i)*rs_size); + ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i) * rs_size); + ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i) * rs_size); ggml_format_name(r, "cache_r_l%d", i); ggml_format_name(s, "cache_s_l%d", i); cache.rs.r_l.push_back(r); @@ -4370,8 +4365,8 @@ struct llama_kv_slot_restorer { bool do_restore = false; explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) { - old_state.head = cache.kv.head; - old_state.n = cache.kv.n; + old_state.head = cache.kv.head; + old_state.n = cache.kv.n; } // saves a slot information for future restoration @@ -4388,10 +4383,10 @@ struct llama_kv_slot_restorer { // and rollback changes from all llama_kv_cache_find_slot calls void restore(struct llama_kv_cache & cache) { if (do_restore) { - cache.kv.head = old_state.head; - cache.kv.n = old_state.n; + cache.kv.head = old_state.head; + cache.kv.n = old_state.n; - if (cache.rs.size > 0) { // recurrent models like Mamba or RWKV can't have a state partially erased + if (cache.rs.size > 0) { // recurrent models like Mamba or RWKV can't have a state partially erased llama_kv_cache_seq_rm(cache, -1, -1, -1); } else { for (auto & slot : slot_boundaries) {