From 8a6b1c860c78205e5a87449b5c2a8af578c73253 Mon Sep 17 00:00:00 2001
From: eauchs
Date: Thu, 5 Mar 2026 11:15:14 +0100
Subject: [PATCH] fix: implement recurrent state checkpointing for
 has_cell=true path

The checkpoint mechanism in find_slot only triggered when a sequence
moved to a new cell (has_cell=false), which never occurs during normal
single-sequence autoregressive generation. As a result, seq_rm had no
checkpoint to roll back to during speculative decoding rejection.

Fix: add checkpoint creation in the has_cell=true branch. Before the
current cell is overwritten with new tokens, its SSM state (r_l/s_l) is
copied to a free cell and kept as a checkpoint. This makes the rollback
history available for the common single-sequence case.

Also replace the soft rollback in seq_rm (which only rewound position
metadata, leaving tensor state corrupted) with a proper return false,
signaling to the caller that re-evaluation is required when no
checkpoint exists at p0-1.

Co-Authored-By: Claude Sonnet 4.6
---
 src/llama-memory-recurrent.cpp | 51 ++++++++++++++++++++++++++++++++--
 1 file changed, 48 insertions(+), 3 deletions(-)

diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp
index 4286d6ca2a..4d62ad9e08 100644
--- a/src/llama-memory-recurrent.cpp
+++ b/src/llama-memory-recurrent.cpp
@@ -175,9 +175,9 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
             if (best_cell >= 0) {
                 tail_id = best_cell;
             } else {
-                // if no checkpoint found, we still move the position back (soft rollback)
-                // only if it's the current sequence's tail
-                cells[tail_id].pos = p0 - 1;
+                // No checkpoint found at p0-1: SSM tensor state cannot be rolled back
+                // without re-evaluating the sequence. Signal failure to the caller.
+                return false;
             }
         }
         // invalidate tails which will be cleared
@@ -683,6 +683,51 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
                     if (cell.is_empty()) { break; }
                 }
             }
+        } else {
+            // Sequence owns its cell. Save a checkpoint of the current state before it is
+            // overwritten by new tokens. This is required for speculative decoding rollback
+            // in recurrent/SSM models where tensor state cannot be partially rewound.
+            const int32_t cur_tail = seq_meta.tail;
+            if (cells[next_empty_cell].is_empty()) {
+                bool can_checkpoint = (get_cell_count(seq_id) < 8 && used < size * 0.9);
+                if (!can_checkpoint) {
+                    // Try to evict the oldest checkpoint to make room
+                    int32_t oldest = -1;
+                    llama_pos min_pos = std::numeric_limits<llama_pos>::max();
+                    for (uint32_t j = 0; j < size; ++j) {
+                        if ((int32_t)j != cur_tail && cells[j].has_seq_id(seq_id) && cells[j].pos < min_pos) {
+                            min_pos = cells[j].pos;
+                            oldest = j;
+                        }
+                    }
+                    if (oldest >= 0) {
+                        cells[oldest].seq_id.erase(seq_id);
+                        if (cells[oldest].is_empty()) {
+                            cells[oldest].pos = -1;
+                            cells[oldest].src = -1;
+                            used--;
+                        }
+                        can_checkpoint = true;
+                    }
+                }
+                if (can_checkpoint) {
+                    auto & cp_cell = cells[next_empty_cell];
+                    copy_cell(cur_tail, next_empty_cell);
+                    cp_cell.pos = cells[cur_tail].pos;
+                    cp_cell.src = next_empty_cell; // independent copy, no further movement needed
+                    cp_cell.seq_id.insert(seq_id);
+                    used++;
+                    // advance next_empty_cell for subsequent sequences in this batch
+                    if (s + 1 < n_seqs) {
+                        for (uint32_t j = 0; j < size; ++j) {
+                            next_empty_cell += 1;
+                            if (next_empty_cell >= size) { next_empty_cell -= size; }
+                            if (cells[next_empty_cell].is_empty()) { break; }
+                        }
+                    }
+                }
+            }
+            // seq_meta.tail remains unchanged - sequence still owns its current cell
         }
         if (min > seq_meta.tail) { min = seq_meta.tail; }
         if (max < seq_meta.tail) { max = seq_meta.tail; }