fix: implement recurrent state checkpointing for has_cell=true path

The checkpoint mechanism in find_slot only triggered when a sequence
moved to a new cell (has_cell=false), which never occurs during normal
single-sequence autoregressive generation. As a result, seq_rm had no
checkpoint to roll back to during speculative decoding rejection.

Fix: add checkpoint creation in the has_cell=true branch. Before the
current cell is overwritten with new tokens, its SSM state (r_l/s_l)
is copied to a free cell and kept as a checkpoint. This makes the
rollback history available for the common single-sequence case.

Also replace the soft rollback in seq_rm (which only rewound position
metadata, leaving tensor state corrupted) with a proper return false,
signaling to the caller that re-evaluation is required when no
checkpoint exists at p0-1.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
eauchs 2026-03-05 11:15:14 +01:00
parent 8fe0e0353e
commit 8a6b1c860c
1 changed file with 48 additions and 3 deletions

View File

@@ -175,9 +175,9 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
if (best_cell >= 0) {
tail_id = best_cell;
} else {
// if no checkpoint found, we still move the position back (soft rollback)
// only if it's the current sequence's tail
cells[tail_id].pos = p0 - 1;
// No checkpoint found at p0-1: SSM tensor state cannot be rolled back
// without re-evaluating the sequence. Signal failure to the caller.
return false;
}
}
// invalidate tails which will be cleared
@@ -683,6 +683,51 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
if (cell.is_empty()) { break; }
}
}
} else {
// Sequence owns its cell. Save a checkpoint of the current state before it is
// overwritten by new tokens. This is required for speculative decoding rollback
// in recurrent/SSM models where tensor state cannot be partially rewound.
const int32_t cur_tail = seq_meta.tail;
if (cells[next_empty_cell].is_empty()) {
bool can_checkpoint = (get_cell_count(seq_id) < 8 && used < size * 0.9);
if (!can_checkpoint) {
// Try to evict the oldest checkpoint to make room
int32_t oldest = -1;
llama_pos min_pos = std::numeric_limits<llama_pos>::max();
for (uint32_t j = 0; j < size; ++j) {
if ((int32_t)j != cur_tail && cells[j].has_seq_id(seq_id) && cells[j].pos < min_pos) {
min_pos = cells[j].pos;
oldest = j;
}
}
if (oldest >= 0) {
cells[oldest].seq_id.erase(seq_id);
if (cells[oldest].is_empty()) {
cells[oldest].pos = -1;
cells[oldest].src = -1;
used--;
}
can_checkpoint = true;
}
}
if (can_checkpoint) {
auto & cp_cell = cells[next_empty_cell];
copy_cell(cur_tail, next_empty_cell);
cp_cell.pos = cells[cur_tail].pos;
cp_cell.src = next_empty_cell; // independent copy, no further movement needed
cp_cell.seq_id.insert(seq_id);
used++;
// advance next_empty_cell for subsequent sequences in this batch
if (s + 1 < n_seqs) {
for (uint32_t j = 0; j < size; ++j) {
next_empty_cell += 1;
if (next_empty_cell >= size) { next_empty_cell -= size; }
if (cells[next_empty_cell].is_empty()) { break; }
}
}
}
}
// seq_meta.tail remains unchanged - sequence still owns its current cell
}
if (min > seq_meta.tail) { min = seq_meta.tail; }
if (max < seq_meta.tail) { max = seq_meta.tail; }