This commit is contained in:
Théophile Lafargue 2026-03-15 22:49:42 +01:00 committed by GitHub
commit 96edb94072
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 188 additions and 22 deletions

View File

@ -163,12 +163,41 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
const auto & cell = cells[tail_id];
// partial intersection is invalid if it includes the final pos
if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1);
return false;
// for speculative decoding, we search for a checkpoint in the history
int32_t best_cell = -1;
for (uint32_t i = 0; i < size; ++i) {
if (cells[i].has_seq_id(seq_id) && cells[i].pos == p0 - 1) {
best_cell = i;
break;
}
}
if (best_cell >= 0) {
tail_id = best_cell;
} else {
// No checkpoint found at p0-1: SSM tensor state cannot be rolled back
// without re-evaluating the sequence. Signal failure to the caller.
return false;
}
}
// invalidate tails which will be cleared
if (p0 <= cell.pos && cell.pos < p1) {
tail_id = -1;
if (p0 == 0) {
tail_id = -1;
} else {
// Search for the best remaining cell after removal
int32_t new_tail = -1;
llama_pos max_pos = -1;
for (uint32_t i = 0; i < size; ++i) {
if (cells[i].has_seq_id(seq_id) && cells[i].pos < p0) {
if (cells[i].pos > max_pos) {
max_pos = cells[i].pos;
new_tail = i;
}
}
}
tail_id = new_tail;
}
}
}
} else {
@ -184,6 +213,11 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
if (seq_id < 0) {
cells[i].seq_id.clear();
} else if (cells[i].has_seq_id(seq_id)) {
if (p0 > 0 && p1 == std::numeric_limits<llama_pos>::max()) {
// partial removal: just move the position back
cells[i].pos = p0 - 1;
continue;
}
cells[i].seq_id.erase(seq_id);
} else {
continue;
@ -224,25 +258,42 @@ void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
}
if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
auto & tail_src = cells[seq_id_src];
auto & tail_dst = cells[seq_id_dst];
if (tail_dst.tail >= 0) {
auto & tail_src_meta = cells[seq_id_src];
auto & tail_dst_meta = cells[seq_id_dst];
if (tail_dst_meta.tail >= 0) {
// clear destination seq_id if it wasn't empty
auto & cell_dst = cells[tail_dst.tail];
cell_dst.seq_id.erase(seq_id_dst);
tail_dst.tail = -1;
if (cell_dst.seq_id.empty()) {
cell_dst.pos = -1;
cell_dst.src = -1;
used -= 1;
}
seq_rm(seq_id_dst, -1, -1);
}
if (tail_src.tail >= 0) {
auto & cell_src = cells[tail_src.tail];
cell_src.seq_id.insert(seq_id_dst);
tail_dst.tail = tail_src.tail;
if (tail_src_meta.tail >= 0) {
auto & cell_src = cells[tail_src_meta.tail];
// For recurrent models, we must copy the state to a new cell
// Otherwise, both sequences would share the same mutable state
uint32_t next_empty_cell = size;
for (uint32_t i = head; i < head + size; ++i) {
uint32_t idx = i % size;
if (cells[idx].is_empty()) {
next_empty_cell = idx;
break;
}
}
if (next_empty_cell != size) {
auto & empty_cell = cells[next_empty_cell];
// Copy tensors data
copy_cell(tail_src_meta.tail, next_empty_cell);
empty_cell.pos = cell_src.pos;
empty_cell.src = next_empty_cell; // results in a copy in the graph if needed
empty_cell.seq_id.insert(seq_id_dst);
tail_dst_meta.tail = next_empty_cell;
used += 1;
} else {
LLAMA_LOG_ERROR("%s: failed to find available cell for copy\n", __func__);
}
}
}
}
@ -367,6 +418,47 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
return result;
}
void llama_memory_recurrent::copy_cell(int32_t i_src, int32_t i_dst) {
    // Copy the recurrent-state tensor rows (r and s) of cell i_src into cell i_dst.
    // Used to duplicate a sequence's SSM state so two cells do not alias the same
    // mutable tensor data. No-op for invalid indices or a self-copy.
    if (i_src == i_dst || i_src < 0 || i_dst < 0) {
        return;
    }

    // A single metadata-only context is enough for every temporary view below:
    // at most 2 tensors (r_l, s_l) per layer, 2 views (src, dst) each.
    // This avoids creating and destroying up to 2*n_layer contexts per call.
    ggml_init_params params = {
        /*.mem_size   =*/ size_t(4*hparams.n_layer*ggml_tensor_overhead()),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };

    ggml_context * ctx = ggml_init(params);
    if (ctx == NULL) {
        // metadata pool allocation failed; nothing was copied
        return;
    }

    for (uint32_t il = 0; il < hparams.n_layer; ++il) {
        if (r_l[il]) {
            // one row of the recurrent (conv) state per cell
            size_t r_row_size = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
            ggml_tensor * src_v = ggml_view_1d(ctx, r_l[il], r_row_size, i_src * r_row_size);
            ggml_tensor * dst_v = ggml_view_1d(ctx, r_l[il], r_row_size, i_dst * r_row_size);
            ggml_backend_tensor_copy(src_v, dst_v);
        }
        if (s_l[il]) {
            // one row of the SSM state per cell
            size_t s_row_size = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
            ggml_tensor * src_v = ggml_view_1d(ctx, s_l[il], s_row_size, i_src * s_row_size);
            ggml_tensor * dst_v = ggml_view_1d(ctx, s_l[il], s_row_size, i_dst * s_row_size);
            ggml_backend_tensor_copy(src_v, dst_v);
        }
    }

    ggml_free(ctx);
}
int llama_memory_recurrent::get_cell_count(llama_seq_id seq_id) const {
    // Number of cells currently tagged with seq_id (the tail cell plus any
    // checkpoint cells kept for rollback).
    int n_cells = 0;
    for (uint32_t i = 0; i < size; ++i) {
        n_cells += cells[i].has_seq_id(seq_id) ? 1 : 0;
    }
    return n_cells;
}
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
std::map<ggml_backend_buffer_type_t, size_t> ret;
for (const auto & [_, buf] : ctxs_bufs) {
@ -551,10 +643,35 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
if (seq_meta.tail >= 0) {
auto & orig_cell = cells[seq_meta.tail];
empty_cell.pos = orig_cell.pos;
empty_cell.src = orig_cell.src;
orig_cell.seq_id.erase(seq_id);
empty_cell.src = seq_meta.tail; // the data should be copied from the previous tail
// Copy state data
copy_cell(seq_meta.tail, next_empty_cell);
// Keep history of previous states for rollback (up to 8 cells per sequence)
if (get_cell_count(seq_id) < 8 && used < size * 0.9) {
// Do not erase seq_id from orig_cell to keep it as a checkpoint
} else {
// Erase oldest history point for this sequence
int32_t oldest_cell = -1;
llama_pos min_pos = std::numeric_limits<llama_pos>::max();
for (uint32_t i = 0; i < size; ++i) {
if (cells[i].has_seq_id(seq_id) && cells[i].pos < min_pos) {
min_pos = cells[i].pos;
oldest_cell = i;
}
}
if (oldest_cell >= 0) {
cells[oldest_cell].seq_id.erase(seq_id);
if (cells[oldest_cell].is_empty()) {
cells[oldest_cell].pos = -1;
cells[oldest_cell].src = -1;
used--;
}
}
}
empty_cell.seq_id.insert(seq_id); // will be overwritten
GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
}
seq_meta.tail = next_empty_cell;
// find next empty cell
@ -566,6 +683,51 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
if (cell.is_empty()) { break; }
}
}
} else {
// Sequence owns its cell. Save a checkpoint of the current state before it is
// overwritten by new tokens. This is required for speculative decoding rollback
// in recurrent/SSM models where tensor state cannot be partially rewound.
const int32_t cur_tail = seq_meta.tail;
if (cells[next_empty_cell].is_empty()) {
bool can_checkpoint = (get_cell_count(seq_id) < 8 && used < size * 0.9);
if (!can_checkpoint) {
// Try to evict the oldest checkpoint to make room
int32_t oldest = -1;
llama_pos min_pos = std::numeric_limits<llama_pos>::max();
for (uint32_t j = 0; j < size; ++j) {
if ((int32_t)j != cur_tail && cells[j].has_seq_id(seq_id) && cells[j].pos < min_pos) {
min_pos = cells[j].pos;
oldest = j;
}
}
if (oldest >= 0) {
cells[oldest].seq_id.erase(seq_id);
if (cells[oldest].is_empty()) {
cells[oldest].pos = -1;
cells[oldest].src = -1;
used--;
}
can_checkpoint = true;
}
}
if (can_checkpoint) {
auto & cp_cell = cells[next_empty_cell];
copy_cell(cur_tail, next_empty_cell);
cp_cell.pos = cells[cur_tail].pos;
cp_cell.src = next_empty_cell; // independent copy, no further movement needed
cp_cell.seq_id.insert(seq_id);
used++;
// advance next_empty_cell for subsequent sequences in this batch
if (s + 1 < n_seqs) {
for (uint32_t j = 0; j < size; ++j) {
next_empty_cell += 1;
if (next_empty_cell >= size) { next_empty_cell -= size; }
if (cells[next_empty_cell].is_empty()) { break; }
}
}
}
}
// seq_meta.tail remains unchanged - sequence still owns its current cell
}
if (min > seq_meta.tail) { min = seq_meta.tail; }
if (max < seq_meta.tail) { max = seq_meta.tail; }

View File

@ -65,6 +65,10 @@ public:
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
// cell management
void copy_cell(int32_t i_src, int32_t i_dst);
int get_cell_count(llama_seq_id seq_id) const;
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
uint32_t size = 0; // total number of cells, shared across all sequences
uint32_t used = 0; // used cells (i.e. at least one seq_id)