diff --git a/include/llama.h b/include/llama.h index a940f9d648..cb5eb82c2f 100644 --- a/include/llama.h +++ b/include/llama.h @@ -883,6 +883,9 @@ extern "C" { // work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba) #define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1 +// when restoring a single sequence, do not remove the destination sequence's existing cells first (the restored state is appended to the existing KV cache entries) +#define LLAMA_STATE_SEQ_FLAGS_APPEND 2 + typedef uint32_t llama_state_seq_flags; LLAMA_API size_t llama_state_seq_get_size_ext( diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 3e0fd3107f..837e661faa 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1906,7 +1906,7 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama slot_info sinfo; bool res = true; - res = res && state_read_meta(io, strm, cell_count, sinfo, seq_id); + res = res && state_read_meta(io, strm, cell_count, sinfo, seq_id, flags); res = res && state_read_data(io, strm, cell_count, sinfo); if (!res) { @@ -2052,13 +2052,15 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t } } -bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id) { +bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id, llama_state_seq_flags flags) { auto & cells = v_cells[strm]; auto & head = v_heads[strm]; if (dest_seq_id != -1) { // single sequence - seq_rm(dest_seq_id, -1, -1); + if (!(flags & LLAMA_STATE_SEQ_FLAGS_APPEND)) { + seq_rm(dest_seq_id, -1, -1); + } llama_batch_allocr balloc(hparams.n_pos_per_embd()); diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index d4569a06f7..9d396206ec 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -297,7 +297,7 @@ private: void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const; void 
state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const; - bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id = -1); + bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id = -1, llama_state_seq_flags flags = 0); bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo); };