From 04e2fb15f3537561c83b61e51872e00b41a3a906 Mon Sep 17 00:00:00 2001 From: eauchs Date: Tue, 3 Mar 2026 15:55:36 +0100 Subject: [PATCH] fix: implement synchronous recurrent state checkpointing for hybrid models --- src/llama-memory-recurrent.cpp | 161 ++++++++++++++++++++++++++++----- src/llama-memory-recurrent.h | 4 + 2 files changed, 143 insertions(+), 22 deletions(-) diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 6e8413f493..4286d6ca2a 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -163,12 +163,41 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos const auto & cell = cells[tail_id]; // partial intersection is invalid if it includes the final pos if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { - //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1); - return false; + // for speculative decoding, we search for a checkpoint in the history + int32_t best_cell = -1; + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id) && cells[i].pos == p0 - 1) { + best_cell = i; + break; + } + } + + if (best_cell >= 0) { + tail_id = best_cell; + } else { + // if no checkpoint found, we still move the position back (soft rollback) + // only if it's the current sequence's tail + cells[tail_id].pos = p0 - 1; + } } // invalidate tails which will be cleared if (p0 <= cell.pos && cell.pos < p1) { - tail_id = -1; + if (p0 == 0) { + tail_id = -1; + } else { + // Search for the best remaining cell after removal + int32_t new_tail = -1; + llama_pos max_pos = -1; + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id) && cells[i].pos < p0) { + if (cells[i].pos > max_pos) { + max_pos = cells[i].pos; + new_tail = i; + } + } + } + tail_id = new_tail; + } } } } else { @@ -184,6 +213,11 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id 
seq_id, llama_pos p0, llama_pos if (seq_id < 0) { cells[i].seq_id.clear(); } else if (cells[i].has_seq_id(seq_id)) { + if (p0 > 0 && p1 == std::numeric_limits<llama_pos>::max()) { + // partial removal: just move the position back + cells[i].pos = p0 - 1; + continue; + } cells[i].seq_id.erase(seq_id); } else { continue; } @@ -224,25 +258,42 @@ void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id } if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { - auto & tail_src = cells[seq_id_src]; - auto & tail_dst = cells[seq_id_dst]; - if (tail_dst.tail >= 0) { + auto & tail_src_meta = cells[seq_id_src]; + auto & tail_dst_meta = cells[seq_id_dst]; + + if (tail_dst_meta.tail >= 0) { // clear destination seq_id if it wasn't empty - auto & cell_dst = cells[tail_dst.tail]; - - cell_dst.seq_id.erase(seq_id_dst); - tail_dst.tail = -1; - if (cell_dst.seq_id.empty()) { - cell_dst.pos = -1; - cell_dst.src = -1; - used -= 1; - } + seq_rm(seq_id_dst, -1, -1); } - if (tail_src.tail >= 0) { - auto & cell_src = cells[tail_src.tail]; - cell_src.seq_id.insert(seq_id_dst); - tail_dst.tail = tail_src.tail; + if (tail_src_meta.tail >= 0) { + auto & cell_src = cells[tail_src_meta.tail]; + + // For recurrent models, we must copy the state to a new cell + // Otherwise, both sequences would share the same mutable state + uint32_t next_empty_cell = size; + for (uint32_t i = head; i < head + size; ++i) { + uint32_t idx = i % size; + if (cells[idx].is_empty()) { + next_empty_cell = idx; + break; + } + } + + if (next_empty_cell != size) { + auto & empty_cell = cells[next_empty_cell]; + + // Copy tensors data + copy_cell(tail_src_meta.tail, next_empty_cell); + + empty_cell.pos = cell_src.pos; + empty_cell.src = next_empty_cell; // results in a copy in the graph if needed + empty_cell.seq_id.insert(seq_id_dst); + tail_dst_meta.tail = next_empty_cell; + used += 1; + } else { + LLAMA_LOG_ERROR("%s: failed to find available cell for copy\n", __func__); + } } } } @@ -367,6
+418,47 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } +void llama_memory_recurrent::copy_cell(int32_t i_src, int32_t i_dst) { + if (i_src == i_dst || i_src < 0 || i_dst < 0) { + return; + } + + ggml_init_params params = { + /*.mem_size =*/ size_t(2*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + if (r_l[il]) { + ggml_context * ctx = ggml_init(params); + size_t r_row_size = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); + ggml_tensor * src_v = ggml_view_1d(ctx, r_l[il], r_row_size, i_src * r_row_size); + ggml_tensor * dst_v = ggml_view_1d(ctx, r_l[il], r_row_size, i_dst * r_row_size); + ggml_backend_tensor_copy(src_v, dst_v); + ggml_free(ctx); + } + if (s_l[il]) { + ggml_context * ctx = ggml_init(params); + size_t s_row_size = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); + ggml_tensor * src_v = ggml_view_1d(ctx, s_l[il], s_row_size, i_src * s_row_size); + ggml_tensor * dst_v = ggml_view_1d(ctx, s_l[il], s_row_size, i_dst * s_row_size); + ggml_backend_tensor_copy(src_v, dst_v); + ggml_free(ctx); + } + } +} + +int llama_memory_recurrent::get_cell_count(llama_seq_id seq_id) const { + int count = 0; + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + count++; + } + } + return count; +} + std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const { std::map<ggml_backend_buffer_type_t, size_t> ret; for (const auto & [_, buf] : ctxs_bufs) { @@ -551,10 +643,35 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) { if (seq_meta.tail >= 0) { auto & orig_cell = cells[seq_meta.tail]; empty_cell.pos = orig_cell.pos; - empty_cell.src = orig_cell.src; - orig_cell.seq_id.erase(seq_id); + empty_cell.src = seq_meta.tail; // the data should be copied from the previous tail + + // Copy state data + copy_cell(seq_meta.tail, next_empty_cell); + + // Keep history of previous states for rollback (up to 8 cells per sequence) + if
(get_cell_count(seq_id) < 8 && used < size * 0.9) { + // Do not erase seq_id from orig_cell to keep it as a checkpoint + } else { + // Erase oldest history point for this sequence + int32_t oldest_cell = -1; + llama_pos min_pos = std::numeric_limits<llama_pos>::max(); + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id) && cells[i].pos < min_pos) { + min_pos = cells[i].pos; + oldest_cell = i; + } + } + + if (oldest_cell >= 0) { + cells[oldest_cell].seq_id.erase(seq_id); + if (cells[oldest_cell].is_empty()) { + cells[oldest_cell].pos = -1; + cells[oldest_cell].src = -1; + used--; + } + } + } empty_cell.seq_id.insert(seq_id); // will be overwritten - GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id } seq_meta.tail = next_empty_cell; // find next empty cell diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index 47f01d7391..b6b5d6cfbd 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -65,6 +65,10 @@ public: void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; + // cell management + void copy_cell(int32_t i_src, int32_t i_dst); + int get_cell_count(llama_seq_id seq_id) const; + uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) uint32_t size = 0; // total number of cells, shared across all sequences uint32_t used = 0; // used cells (i.e. at least one seq_id)