Merge 8a6b1c860c into 9e2e2198b0
This commit is contained in:
commit
96edb94072
|
|
@ -163,12 +163,41 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
|||
const auto & cell = cells[tail_id];
|
||||
// partial intersection is invalid if it includes the final pos
|
||||
if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
|
||||
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1);
|
||||
return false;
|
||||
// for speculative decoding, we search for a checkpoint in the history
|
||||
int32_t best_cell = -1;
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].has_seq_id(seq_id) && cells[i].pos == p0 - 1) {
|
||||
best_cell = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (best_cell >= 0) {
|
||||
tail_id = best_cell;
|
||||
} else {
|
||||
// No checkpoint found at p0-1: SSM tensor state cannot be rolled back
|
||||
// without re-evaluating the sequence. Signal failure to the caller.
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// invalidate tails which will be cleared
|
||||
if (p0 <= cell.pos && cell.pos < p1) {
|
||||
tail_id = -1;
|
||||
if (p0 == 0) {
|
||||
tail_id = -1;
|
||||
} else {
|
||||
// Search for the best remaining cell after removal
|
||||
int32_t new_tail = -1;
|
||||
llama_pos max_pos = -1;
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].has_seq_id(seq_id) && cells[i].pos < p0) {
|
||||
if (cells[i].pos > max_pos) {
|
||||
max_pos = cells[i].pos;
|
||||
new_tail = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
tail_id = new_tail;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
|
@ -184,6 +213,11 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
|||
if (seq_id < 0) {
|
||||
cells[i].seq_id.clear();
|
||||
} else if (cells[i].has_seq_id(seq_id)) {
|
||||
if (p0 > 0 && p1 == std::numeric_limits<llama_pos>::max()) {
|
||||
// partial removal: just move the position back
|
||||
cells[i].pos = p0 - 1;
|
||||
continue;
|
||||
}
|
||||
cells[i].seq_id.erase(seq_id);
|
||||
} else {
|
||||
continue;
|
||||
|
|
@ -224,25 +258,42 @@ void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
|
|||
}
|
||||
|
||||
if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
|
||||
auto & tail_src = cells[seq_id_src];
|
||||
auto & tail_dst = cells[seq_id_dst];
|
||||
if (tail_dst.tail >= 0) {
|
||||
auto & tail_src_meta = cells[seq_id_src];
|
||||
auto & tail_dst_meta = cells[seq_id_dst];
|
||||
|
||||
if (tail_dst_meta.tail >= 0) {
|
||||
// clear destination seq_id if it wasn't empty
|
||||
auto & cell_dst = cells[tail_dst.tail];
|
||||
|
||||
cell_dst.seq_id.erase(seq_id_dst);
|
||||
tail_dst.tail = -1;
|
||||
if (cell_dst.seq_id.empty()) {
|
||||
cell_dst.pos = -1;
|
||||
cell_dst.src = -1;
|
||||
used -= 1;
|
||||
}
|
||||
seq_rm(seq_id_dst, -1, -1);
|
||||
}
|
||||
if (tail_src.tail >= 0) {
|
||||
auto & cell_src = cells[tail_src.tail];
|
||||
|
||||
cell_src.seq_id.insert(seq_id_dst);
|
||||
tail_dst.tail = tail_src.tail;
|
||||
if (tail_src_meta.tail >= 0) {
|
||||
auto & cell_src = cells[tail_src_meta.tail];
|
||||
|
||||
// For recurrent models, we must copy the state to a new cell
|
||||
// Otherwise, both sequences would share the same mutable state
|
||||
uint32_t next_empty_cell = size;
|
||||
for (uint32_t i = head; i < head + size; ++i) {
|
||||
uint32_t idx = i % size;
|
||||
if (cells[idx].is_empty()) {
|
||||
next_empty_cell = idx;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (next_empty_cell != size) {
|
||||
auto & empty_cell = cells[next_empty_cell];
|
||||
|
||||
// Copy tensors data
|
||||
copy_cell(tail_src_meta.tail, next_empty_cell);
|
||||
|
||||
empty_cell.pos = cell_src.pos;
|
||||
empty_cell.src = next_empty_cell; // results in a copy in the graph if needed
|
||||
empty_cell.seq_id.insert(seq_id_dst);
|
||||
tail_dst_meta.tail = next_empty_cell;
|
||||
used += 1;
|
||||
} else {
|
||||
LLAMA_LOG_ERROR("%s: failed to find available cell for copy\n", __func__);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -367,6 +418,47 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
|||
return result;
|
||||
}
|
||||
|
||||
// Copy the recurrent state tensors (r and s rows) of cell i_src into cell i_dst.
// Used to create an independent copy/checkpoint of a sequence state, since
// recurrent/SSM cells hold mutable state that cannot be shared between cells.
// No-op on identical or negative indices.
void llama_memory_recurrent::copy_cell(int32_t i_src, int32_t i_dst) {
    if (i_src == i_dst || i_src < 0 || i_dst < 0) {
        return;
    }

    // a single metadata-only context is enough for the whole copy:
    // each layer creates at most 4 view tensors (src/dst views for r and s)
    ggml_init_params params = {
        /*.mem_size   =*/ size_t(4*hparams.n_layer*ggml_tensor_overhead()),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };

    ggml_context * ctx = ggml_init(params);
    if (ctx == NULL) {
        LLAMA_LOG_ERROR("%s: failed to create ggml context\n", __func__);
        return;
    }

    for (uint32_t il = 0; il < hparams.n_layer; ++il) {
        if (r_l[il]) {
            // one row of the r state per cell — copy row i_src over row i_dst
            const size_t r_row_size = ggml_row_size(r_l[il]->type, hparams.n_embd_r());

            ggml_tensor * src_v = ggml_view_1d(ctx, r_l[il], r_row_size, i_src * r_row_size);
            ggml_tensor * dst_v = ggml_view_1d(ctx, r_l[il], r_row_size, i_dst * r_row_size);

            ggml_backend_tensor_copy(src_v, dst_v);
        }
        if (s_l[il]) {
            // one row of the s state per cell — copy row i_src over row i_dst
            const size_t s_row_size = ggml_row_size(s_l[il]->type, hparams.n_embd_s());

            ggml_tensor * src_v = ggml_view_1d(ctx, s_l[il], s_row_size, i_src * s_row_size);
            ggml_tensor * dst_v = ggml_view_1d(ctx, s_l[il], s_row_size, i_dst * s_row_size);

            ggml_backend_tensor_copy(src_v, dst_v);
        }
    }

    ggml_free(ctx);
}
|
||||
|
||||
// Return the number of cells that currently carry the given sequence id
// (used e.g. to bound the number of checkpoint cells kept per sequence).
int llama_memory_recurrent::get_cell_count(llama_seq_id seq_id) const {
    int n_cells = 0;

    for (uint32_t ic = 0; ic < size; ++ic) {
        n_cells += cells[ic].has_seq_id(seq_id) ? 1 : 0;
    }

    return n_cells;
}
|
||||
|
||||
std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
|
||||
std::map<ggml_backend_buffer_type_t, size_t> ret;
|
||||
for (const auto & [_, buf] : ctxs_bufs) {
|
||||
|
|
@ -551,10 +643,35 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|||
if (seq_meta.tail >= 0) {
|
||||
auto & orig_cell = cells[seq_meta.tail];
|
||||
empty_cell.pos = orig_cell.pos;
|
||||
empty_cell.src = orig_cell.src;
|
||||
orig_cell.seq_id.erase(seq_id);
|
||||
empty_cell.src = seq_meta.tail; // the data should be copied from the previous tail
|
||||
|
||||
// Copy state data
|
||||
copy_cell(seq_meta.tail, next_empty_cell);
|
||||
|
||||
// Keep history of previous states for rollback (up to 8 cells per sequence)
|
||||
if (get_cell_count(seq_id) < 8 && used < size * 0.9) {
|
||||
// Do not erase seq_id from orig_cell to keep it as a checkpoint
|
||||
} else {
|
||||
// Erase oldest history point for this sequence
|
||||
int32_t oldest_cell = -1;
|
||||
llama_pos min_pos = std::numeric_limits<llama_pos>::max();
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].has_seq_id(seq_id) && cells[i].pos < min_pos) {
|
||||
min_pos = cells[i].pos;
|
||||
oldest_cell = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (oldest_cell >= 0) {
|
||||
cells[oldest_cell].seq_id.erase(seq_id);
|
||||
if (cells[oldest_cell].is_empty()) {
|
||||
cells[oldest_cell].pos = -1;
|
||||
cells[oldest_cell].src = -1;
|
||||
used--;
|
||||
}
|
||||
}
|
||||
}
|
||||
empty_cell.seq_id.insert(seq_id); // will be overwritten
|
||||
GGML_ASSERT(!orig_cell.is_empty()); // has at least one remaining seq_id
|
||||
}
|
||||
seq_meta.tail = next_empty_cell;
|
||||
// find next empty cell
|
||||
|
|
@ -566,6 +683,51 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|||
if (cell.is_empty()) { break; }
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Sequence owns its cell. Save a checkpoint of the current state before it is
|
||||
// overwritten by new tokens. This is required for speculative decoding rollback
|
||||
// in recurrent/SSM models where tensor state cannot be partially rewound.
|
||||
const int32_t cur_tail = seq_meta.tail;
|
||||
if (cells[next_empty_cell].is_empty()) {
|
||||
bool can_checkpoint = (get_cell_count(seq_id) < 8 && used < size * 0.9);
|
||||
if (!can_checkpoint) {
|
||||
// Try to evict the oldest checkpoint to make room
|
||||
int32_t oldest = -1;
|
||||
llama_pos min_pos = std::numeric_limits<llama_pos>::max();
|
||||
for (uint32_t j = 0; j < size; ++j) {
|
||||
if ((int32_t)j != cur_tail && cells[j].has_seq_id(seq_id) && cells[j].pos < min_pos) {
|
||||
min_pos = cells[j].pos;
|
||||
oldest = j;
|
||||
}
|
||||
}
|
||||
if (oldest >= 0) {
|
||||
cells[oldest].seq_id.erase(seq_id);
|
||||
if (cells[oldest].is_empty()) {
|
||||
cells[oldest].pos = -1;
|
||||
cells[oldest].src = -1;
|
||||
used--;
|
||||
}
|
||||
can_checkpoint = true;
|
||||
}
|
||||
}
|
||||
if (can_checkpoint) {
|
||||
auto & cp_cell = cells[next_empty_cell];
|
||||
copy_cell(cur_tail, next_empty_cell);
|
||||
cp_cell.pos = cells[cur_tail].pos;
|
||||
cp_cell.src = next_empty_cell; // independent copy, no further movement needed
|
||||
cp_cell.seq_id.insert(seq_id);
|
||||
used++;
|
||||
// advance next_empty_cell for subsequent sequences in this batch
|
||||
if (s + 1 < n_seqs) {
|
||||
for (uint32_t j = 0; j < size; ++j) {
|
||||
next_empty_cell += 1;
|
||||
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
||||
if (cells[next_empty_cell].is_empty()) { break; }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// seq_meta.tail remains unchanged - sequence still owns its current cell
|
||||
}
|
||||
if (min > seq_meta.tail) { min = seq_meta.tail; }
|
||||
if (max < seq_meta.tail) { max = seq_meta.tail; }
|
||||
|
|
|
|||
|
|
@ -65,6 +65,10 @@ public:
|
|||
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
|
||||
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
|
||||
|
||||
// cell management
|
||||
void copy_cell(int32_t i_src, int32_t i_dst);
|
||||
int get_cell_count(llama_seq_id seq_id) const;
|
||||
|
||||
uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
|
||||
uint32_t size = 0; // total number of cells, shared across all sequences
|
||||
uint32_t used = 0; // used cells (i.e. at least one seq_id)
|
||||
|
|
|
|||
Loading…
Reference in New Issue