This commit is contained in:
Samuel Shen 2026-04-03 12:30:22 +03:00 committed by GitHub
commit 34140b8b24
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 9 additions and 4 deletions

View File

@ -883,6 +883,9 @@ extern "C" {
// work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba)
#define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1
// restore without clearing existing sequence data (append to existing KV cache entries)
#define LLAMA_STATE_SEQ_FLAGS_APPEND 2
typedef uint32_t llama_state_seq_flags;
LLAMA_API size_t llama_state_seq_get_size_ext(

View File

@ -1906,7 +1906,7 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama
slot_info sinfo;
bool res = true;
res = res && state_read_meta(io, strm, cell_count, sinfo, seq_id);
res = res && state_read_meta(io, strm, cell_count, sinfo, seq_id, flags);
res = res && state_read_data(io, strm, cell_count, sinfo);
if (!res) {
@ -2052,13 +2052,15 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
}
}
bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id) {
bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id, llama_state_seq_flags flags) {
auto & cells = v_cells[strm];
auto & head = v_heads[strm];
if (dest_seq_id != -1) {
// single sequence
seq_rm(dest_seq_id, -1, -1);
if (!(flags & LLAMA_STATE_SEQ_FLAGS_APPEND)) {
seq_rm(dest_seq_id, -1, -1);
}
llama_batch_allocr balloc(hparams.n_pos_per_embd());

View File

@ -297,7 +297,7 @@ private:
void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id = -1);
bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id = -1, llama_state_seq_flags flags = 0);
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo);
};