From 8fe0e0353e5b9cacf7fcf7dcb9639438bdf4cdb1 Mon Sep 17 00:00:00 2001 From: eauchs Date: Thu, 5 Mar 2026 11:11:38 +0100 Subject: [PATCH] Revert "fix: replace soft rollback with proper failure in recurrent seq_rm" This reverts commit 9a04ac4e10768e1786eb0c972da52cdf8168ffcd. --- src/llama-memory-recurrent.cpp | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index a1217a1fc2..4286d6ca2a 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -175,9 +175,9 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (best_cell >= 0) { tail_id = best_cell; } else { - // no checkpoint found at p0-1: the SSM tensor state cannot be rolled back - // without re-evaluating the sequence. Signal failure to the caller. - return false; + // if no checkpoint found, we still move the position back (soft rollback) + // only if it's the current sequence's tail + cells[tail_id].pos = p0 - 1; } } // invalidate tails which will be cleared @@ -648,17 +648,10 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) { // Copy state data copy_cell(seq_meta.tail, next_empty_cell); - // Keep history of previous states for rollback (up to 8 cells per sequence). - // The 0.9 threshold prevents the checkpoint history from filling the cache. - // When the cache is too full to keep checkpoints, speculative decoding rollback - // will fail (seq_rm returns false) and the caller must re-evaluate. 
+ // Keep history of previous states for rollback (up to 8 cells per sequence) if (get_cell_count(seq_id) < 8 && used < size * 0.9) { // Do not erase seq_id from orig_cell to keep it as a checkpoint } else { - if (used >= size * 0.9) { - LLAMA_LOG_DEBUG("%s: cache too full (used=%u/%u) to keep checkpoint for seq %d; speculative rollback will require re-evaluation\n", - __func__, used, size, seq_id); - } // Erase oldest history point for this sequence int32_t oldest_cell = -1; llama_pos min_pos = std::numeric_limits<llama_pos>::max();