From 04e2fb15f3537561c83b61e51872e00b41a3a906 Mon Sep 17 00:00:00 2001 From: eauchs Date: Tue, 3 Mar 2026 15:55:36 +0100 Subject: [PATCH 1/4] fix: implement synchronous recurrent state checkpointing for hybrid models --- src/llama-memory-recurrent.cpp | 161 ++++++++++++++++++++++++++++----- src/llama-memory-recurrent.h | 4 + 2 files changed, 143 insertions(+), 22 deletions(-) diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 6e8413f493..4286d6ca2a 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -163,12 +163,41 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos const auto & cell = cells[tail_id]; // partial intersection is invalid if it includes the final pos if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { - //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1); - return false; + // for speculative decoding, we search for a checkpoint in the history + int32_t best_cell = -1; + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id) && cells[i].pos == p0 - 1) { + best_cell = i; + break; + } + } + + if (best_cell >= 0) { + tail_id = best_cell; + } else { + // if no checkpoint found, we still move the position back (soft rollback) + // only if it's the current sequence's tail + cells[tail_id].pos = p0 - 1; + } } // invalidate tails which will be cleared if (p0 <= cell.pos && cell.pos < p1) { - tail_id = -1; + if (p0 == 0) { + tail_id = -1; + } else { + // Search for the best remaining cell after removal + int32_t new_tail = -1; + llama_pos max_pos = -1; + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id) && cells[i].pos < p0) { + if (cells[i].pos > max_pos) { + max_pos = cells[i].pos; + new_tail = i; + } + } + } + tail_id = new_tail; + } } } } else { @@ -184,6 +213,11 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (seq_id < 0) { cells[i].seq_id.clear(); } else if (cells[i].has_seq_id(seq_id)) { + if (p0 > 0 && p1 == std::numeric_limits::max()) { + // partial removal: just move the position back + cells[i].pos = p0 - 1; + continue; + } cells[i].seq_id.erase(seq_id); } else { continue; @@ -224,25 +258,42 @@ void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id } if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { - auto & tail_src = cells[seq_id_src]; - auto & tail_dst = cells[seq_id_dst]; - if (tail_dst.tail >= 0) { + auto & tail_src_meta = cells[seq_id_src]; + auto & tail_dst_meta = cells[seq_id_dst]; + + if (tail_dst_meta.tail >= 0) { // clear destination seq_id if it wasn't empty - auto & cell_dst = cells[tail_dst.tail]; - - cell_dst.seq_id.erase(seq_id_dst); - tail_dst.tail = -1; - if (cell_dst.seq_id.empty()) { - cell_dst.pos = -1; - cell_dst.src = -1; - used -= 1; - } + seq_rm(seq_id_dst, -1, -1); } - if (tail_src.tail >= 0) { - auto & cell_src = cells[tail_src.tail]; - cell_src.seq_id.insert(seq_id_dst); - tail_dst.tail = tail_src.tail; + if (tail_src_meta.tail >= 0) { + auto & cell_src = cells[tail_src_meta.tail]; + + // For recurrent models, we must copy the state to a new cell + // Otherwise, both sequences would share the same mutable state + uint32_t next_empty_cell = size; + for (uint32_t i = head; i < head + size; ++i) { + uint32_t idx = i % size; + if (cells[idx].is_empty()) { + next_empty_cell = idx; + break; + } + } + + if (next_empty_cell != size) { + auto & empty_cell = cells[next_empty_cell]; + + // Copy tensors data + copy_cell(tail_src_meta.tail, next_empty_cell); + + empty_cell.pos = cell_src.pos; + empty_cell.src = next_empty_cell; // results in a copy in the graph if needed + empty_cell.seq_id.insert(seq_id_dst); + tail_dst_meta.tail = next_empty_cell; + used += 1; + } else { + LLAMA_LOG_ERROR("%s: failed to find available cell for copy\n", __func__); + } } } } @@ -367,6 +418,47 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } +void llama_memory_recurrent::copy_cell(int32_t i_src, int32_t i_dst) { + if (i_src == i_dst || i_src < 0 || i_dst < 0) { + return; + } + + ggml_init_params params = { + /*.mem_size =*/ size_t(2*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + if (r_l[il]) { + ggml_context * ctx = ggml_init(params); + size_t r_row_size = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); + ggml_tensor * src_v = ggml_view_1d(ctx, r_l[il], r_row_size, i_src * r_row_size); + ggml_tensor * dst_v = ggml_view_1d(ctx, r_l[il], r_row_size, i_dst * r_row_size); + ggml_backend_tensor_copy(src_v, dst_v); + ggml_free(ctx); + } + if (s_l[il]) { + ggml_context * ctx = ggml_init(params); + size_t s_row_size = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); + ggml_tensor * src_v = ggml_view_1d(ctx, s_l[il], s_row_size, i_src * s_row_size); + ggml_tensor * dst_v = ggml_view_1d(ctx, s_l[il], s_row_size, i_dst * s_row_size); + ggml_backend_tensor_copy(src_v, dst_v); + ggml_free(ctx); + } + } +} + +int llama_memory_recurrent::get_cell_count(llama_seq_id seq_id) const { + int count = 0; + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + count++; + } + } + return count; +} + std::map llama_memory_recurrent::memory_breakdown() const { std::map ret; for (const auto & [_, buf] : ctxs_bufs) { @@ -551,10 +643,35 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) { if (seq_meta.tail >= 0) { auto & orig_cell = cells[seq_meta.tail]; empty_cell.pos = orig_cell.pos; - empty_cell.src = orig_cell.src; - orig_cell.seq_id.erase(seq_id); + empty_cell.src = seq_meta.tail; // the data should be copied from the previous tail + + // Copy state data + copy_cell(seq_meta.tail, next_empty_cell); + + // Keep history of previous states for rollback (up to 8 cells per sequence) + if (get_cell_count(seq_id) < 8 && used < size * 0.9) { + // Do not erase seq_id from orig_cell to keep it as a checkpoint + } else { + // Erase oldest history point for this sequence + int32_t oldest_cell = -1; + llama_pos min_pos = std::numeric_limits::max(); + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id) && cells[i].pos < min_pos) { + min_pos = cells[i].pos; + oldest_cell = i; + } + } + + if (oldest_cell >= 0) { + cells[oldest_cell].seq_id.erase(seq_id); + if (cells[oldest_cell].is_empty()) { + cells[oldest_cell].pos = -1; + cells[oldest_cell].src = -1; + used--; + } + } + } empty_cell.seq_id.insert(seq_id); // will be overwritten - GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id } seq_meta.tail = next_empty_cell; // find next empty cell diff --git a/src/llama-memory-recurrent.h b/src/llama-memory-recurrent.h index 47f01d7391..b6b5d6cfbd 100644 --- a/src/llama-memory-recurrent.h +++ b/src/llama-memory-recurrent.h @@ -65,6 +65,10 @@ public: void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; + // cell management + void copy_cell(int32_t i_src, int32_t i_dst); + int get_cell_count(llama_seq_id seq_id) const; + uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot()) uint32_t size = 0; // total number of cells, shared across all sequences uint32_t used = 0; // used cells (i.e. at least one seq_id) From 9a04ac4e10768e1786eb0c972da52cdf8168ffcd Mon Sep 17 00:00:00 2001 From: eauchs Date: Thu, 5 Mar 2026 11:02:13 +0100 Subject: [PATCH 2/4] fix: replace soft rollback with proper failure in recurrent seq_rm The soft rollback path (cells[tail_id].pos = p0 - 1) only updated position metadata, leaving SSM tensor state (r_l/s_l) reflecting the post-speculative position. This caused silent state corruption and looping on speculative decoding rejection for recurrent/hybrid models (e.g. Qwen3.5 MoE 27B). seq_rm now returns false when no checkpoint exists at p0-1, correctly signaling to the caller that rollback requires re-evaluation. The hybrid memory layer already propagates false correctly. Also add a LLAMA_LOG_DEBUG when the 0.9 cache threshold prevents checkpoint creation, making the behavior visible rather than silent. Co-Authored-By: Claude Sonnet 4.6 --- src/llama-memory-recurrent.cpp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 4286d6ca2a..a1217a1fc2 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -175,9 +175,9 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (best_cell >= 0) { tail_id = best_cell; } else { - // if no checkpoint found, we still move the position back (soft rollback) - // only if it's the current sequence's tail - cells[tail_id].pos = p0 - 1; + // no checkpoint found at p0-1: the SSM tensor state cannot be rolled back + // without re-evaluating the sequence. Signal failure to the caller. + return false; } } // invalidate tails which will be cleared @@ -648,10 +648,17 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) { // Copy state data copy_cell(seq_meta.tail, next_empty_cell); - // Keep history of previous states for rollback (up to 8 cells per sequence) + // Keep history of previous states for rollback (up to 8 cells per sequence). + // The 0.9 threshold prevents the checkpoint history from filling the cache. + // When the cache is too full to keep checkpoints, speculative decoding rollback + // will fail (seq_rm returns false) and the caller must re-evaluate. if (get_cell_count(seq_id) < 8 && used < size * 0.9) { // Do not erase seq_id from orig_cell to keep it as a checkpoint } else { + if (used >= size * 0.9) { + LLAMA_LOG_DEBUG("%s: cache too full (used=%u/%u) to keep checkpoint for seq %d; speculative rollback will require re-evaluation\n", + __func__, used, size, seq_id); + } // Erase oldest history point for this sequence int32_t oldest_cell = -1; llama_pos min_pos = std::numeric_limits::max(); From 8fe0e0353e5b9cacf7fcf7dcb9639438bdf4cdb1 Mon Sep 17 00:00:00 2001 From: eauchs Date: Thu, 5 Mar 2026 11:11:38 +0100 Subject: [PATCH 3/4] Revert "fix: replace soft rollback with proper failure in recurrent seq_rm" This reverts commit 9a04ac4e10768e1786eb0c972da52cdf8168ffcd. --- src/llama-memory-recurrent.cpp | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index a1217a1fc2..4286d6ca2a 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -175,9 +175,9 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (best_cell >= 0) { tail_id = best_cell; } else { - // no checkpoint found at p0-1: the SSM tensor state cannot be rolled back - // without re-evaluating the sequence. Signal failure to the caller. - return false; + // if no checkpoint found, we still move the position back (soft rollback) + // only if it's the current sequence's tail + cells[tail_id].pos = p0 - 1; } } // invalidate tails which will be cleared @@ -648,17 +648,10 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) { // Copy state data copy_cell(seq_meta.tail, next_empty_cell); - // Keep history of previous states for rollback (up to 8 cells per sequence). - // The 0.9 threshold prevents the checkpoint history from filling the cache. - // When the cache is too full to keep checkpoints, speculative decoding rollback - // will fail (seq_rm returns false) and the caller must re-evaluate. + // Keep history of previous states for rollback (up to 8 cells per sequence) if (get_cell_count(seq_id) < 8 && used < size * 0.9) { // Do not erase seq_id from orig_cell to keep it as a checkpoint } else { - if (used >= size * 0.9) { - LLAMA_LOG_DEBUG("%s: cache too full (used=%u/%u) to keep checkpoint for seq %d; speculative rollback will require re-evaluation\n", - __func__, used, size, seq_id); - } // Erase oldest history point for this sequence int32_t oldest_cell = -1; llama_pos min_pos = std::numeric_limits::max(); From 8a6b1c860c78205e5a87449b5c2a8af578c73253 Mon Sep 17 00:00:00 2001 From: eauchs Date: Thu, 5 Mar 2026 11:15:14 +0100 Subject: [PATCH 4/4] fix: implement recurrent state checkpointing for has_cell=true path The checkpoint mechanism in find_slot only triggered when a sequence moved to a new cell (has_cell=false), which never occurs during normal single-sequence autoregressive generation. As a result, seq_rm had no checkpoint to roll back to during speculative decoding rejection. Fix: add checkpoint creation in the has_cell=true branch. Before the current cell is overwritten with new tokens, its SSM state (r_l/s_l) is copied to a free cell and kept as a checkpoint. This makes the rollback history available for the common single-sequence case. Also replace the soft rollback in seq_rm (which only rewound position metadata, leaving tensor state corrupted) with a proper return false, signaling to the caller that re-evaluation is required when no checkpoint exists at p0-1. Co-Authored-By: Claude Sonnet 4.6 --- src/llama-memory-recurrent.cpp | 51 ++++++++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 3 deletions(-) diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 4286d6ca2a..4d62ad9e08 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -175,9 +175,9 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (best_cell >= 0) { tail_id = best_cell; } else { - // if no checkpoint found, we still move the position back (soft rollback) - // only if it's the current sequence's tail - cells[tail_id].pos = p0 - 1; + // No checkpoint found at p0-1: SSM tensor state cannot be rolled back + // without re-evaluating the sequence. Signal failure to the caller. + return false; } } // invalidate tails which will be cleared @@ -683,6 +683,51 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) { if (cell.is_empty()) { break; } } } + } else { + // Sequence owns its cell. Save a checkpoint of the current state before it is + // overwritten by new tokens. This is required for speculative decoding rollback + // in recurrent/SSM models where tensor state cannot be partially rewound. + const int32_t cur_tail = seq_meta.tail; + if (cells[next_empty_cell].is_empty()) { + bool can_checkpoint = (get_cell_count(seq_id) < 8 && used < size * 0.9); + if (!can_checkpoint) { + // Try to evict the oldest checkpoint to make room + int32_t oldest = -1; + llama_pos min_pos = std::numeric_limits::max(); + for (uint32_t j = 0; j < size; ++j) { + if ((int32_t)j != cur_tail && cells[j].has_seq_id(seq_id) && cells[j].pos < min_pos) { + min_pos = cells[j].pos; + oldest = j; + } + } + if (oldest >= 0) { + cells[oldest].seq_id.erase(seq_id); + if (cells[oldest].is_empty()) { + cells[oldest].pos = -1; + cells[oldest].src = -1; + used--; + } + can_checkpoint = true; + } + } + if (can_checkpoint) { + auto & cp_cell = cells[next_empty_cell]; + copy_cell(cur_tail, next_empty_cell); + cp_cell.pos = cells[cur_tail].pos; + cp_cell.src = next_empty_cell; // independent copy, no further movement needed + cp_cell.seq_id.insert(seq_id); + used++; + // advance next_empty_cell for subsequent sequences in this batch + if (s + 1 < n_seqs) { + for (uint32_t j = 0; j < size; ++j) { + next_empty_cell += 1; + if (next_empty_cell >= size) { next_empty_cell -= size; } + if (cells[next_empty_cell].is_empty()) { break; } + } + } + } + } + // seq_meta.tail remains unchanged - sequence still owns its current cell } if (min > seq_meta.tail) { min = seq_meta.tail; } if (max < seq_meta.tail) { max = seq_meta.tail; }