fix: CPU staging copy for recurrent state checkpoint (fixes crash)

Root cause found: copy_cell crashes during find_slot because it calls
ggml_backend_tensor_copy on GPU tensors while the compute graph is
being built. Fixed by using CPU staging: tensor_get (GPU→CPU) then
tensor_set (CPU→GPU).
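
In isolation, the staging pattern looks like this (minimal sketch; the
helper name copy_row_staged is illustrative and not part of the patch):

#include <cstdint>
#include <vector>
#include "ggml-backend.h"

// Copy one state row of `row_size` bytes from row i_src to row i_dst
// of the same backend tensor, bouncing through host memory instead of
// doing a direct ggml_backend_tensor_copy on device views.
static void copy_row_staged(ggml_tensor * t, size_t row_size, int32_t i_src, int32_t i_dst) {
    std::vector<uint8_t> buf(row_size);
    ggml_backend_tensor_get(t, buf.data(), (size_t) i_src * row_size, row_size); // GPU -> CPU
    ggml_backend_tensor_set(t, buf.data(), (size_t) i_dst * row_size, row_size); // CPU -> GPU
}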

Also increased rs_size from 1 to 3 cells per sequence to make room
for checkpoint cells needed by speculative decoding rollback.
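
The sizing arithmetic, as implemented in create_memory below (comments
only, for orientation):

// 1 active cell per sequence, plus 2 cells reserved for checkpoints
// when MTP layers are present (hparams.nextn_predict_layers > 0):
//   rs_per_seq = 1 + 2 = 3
//   rs_size    = n_seq_max * rs_per_seq
// e.g. n_seq_max = 4  ->  12 recurrent cells instead of 4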

Results:
- No more crashes during speculative decode
- 23.8 tok/s with MTP (vs 16.7 without)
- 75% acceptance rate
- Output is still garbled on long generations because seq_rm does not
  find checkpoints at the expected positions (checkpoint position mismatch)

Next: fix checkpoint position tracking so seq_rm can find and restore
the correct recurrent state after draft rejection.
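
One possible shape for that fix, as a sketch only (best_pos is a local
introduced here, not in this commit): instead of requiring an exact
pos == p0 - 1 match, seq_rm could accept the nearest checkpoint at or
before p0 - 1, with the caller re-decoding the tokens from best_pos + 1
up to p0 - 1 afterwards:

// Hypothetical relaxation of the checkpoint search in seq_rm:
int32_t   best_cell = -1;
llama_pos best_pos  = -1;
for (uint32_t i = 0; i < size; ++i) {
    if (cells[i].has_seq_id(seq_id) && cells[i].pos <= p0 - 1 && cells[i].pos > best_pos) {
        best_cell = (int32_t) i;
        best_pos  = cells[i].pos;
    }
}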
itigges22 2026-03-20 12:47:52 -04:00
parent 4aeffc690d
commit 279e6c721e
3 changed files with 70 additions and 20 deletions


@@ -10,6 +10,7 @@
#include <cstring>
#include <limits>
#include <map>
#include <stdexcept>
//
@@ -165,18 +166,26 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1)
if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
// for speculative decoding, we search for a checkpoint in the history
int32_t best_cell = -1;
fprintf(stderr, "[MTP-SEQRM] seq_id=%d, p0=%d, p1=%d, tail_pos=%d, searching for checkpoint at pos=%d\n",
(int)seq_id, (int)p0, (int)p1, (int)cell.pos, (int)(p0-1));
for (uint32_t i = 0; i < size; ++i) {
if (cells[i].has_seq_id(seq_id)) {
fprintf(stderr, "[MTP-SEQRM] cell[%d] pos=%d\n", i, (int)cells[i].pos);
}
if (cells[i].has_seq_id(seq_id) && cells[i].pos == p0 - 1) {
best_cell = i;
break;
}
}
fflush(stderr);
if (best_cell >= 0) {
fprintf(stderr, "[MTP-SEQRM] FOUND checkpoint at cell[%d] pos=%d — rolling back\n", best_cell, (int)(p0-1));
fflush(stderr);
tail_id = best_cell;
} else {
// No checkpoint found at p0-1: SSM tensor state cannot be rolled back
// without re-evaluating the sequence. Signal failure to the caller.
fprintf(stderr, "[MTP-SEQRM] NO checkpoint found — seq_rm FAILED\n");
fflush(stderr);
return false;
}
}
@@ -423,28 +432,31 @@ void llama_memory_recurrent::copy_cell(int32_t i_src, int32_t i_dst) {
return;
}
ggml_init_params params = {
/*.mem_size =*/ size_t(2*ggml_tensor_overhead()),
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ true,
};
fprintf(stderr, "[MTP-COPYCELL] copy_cell(%d -> %d), n_layer=%d\n", i_src, i_dst, (int)hparams.n_layer);
fflush(stderr);
// Copy recurrent state via CPU staging buffer.
// Direct GPU-to-GPU copy via ggml_backend_tensor_copy crashes
// when called from find_slot during graph execution. Use CPU
// as intermediary: GPU→CPU (tensor_get) then CPU→GPU (tensor_set).
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
if (r_l[il]) {
ggml_context * ctx = ggml_init(params);
size_t r_row_size = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
ggml_tensor * src_v = ggml_view_1d(ctx, r_l[il], r_row_size, i_src * r_row_size);
ggml_tensor * dst_v = ggml_view_1d(ctx, r_l[il], r_row_size, i_dst * r_row_size);
ggml_backend_tensor_copy(src_v, dst_v);
ggml_free(ctx);
size_t src_offset = (size_t)i_src * r_row_size;
size_t dst_offset = (size_t)i_dst * r_row_size;
std::vector<uint8_t> buf(r_row_size);
ggml_backend_tensor_get(r_l[il], buf.data(), src_offset, r_row_size);
ggml_backend_tensor_set(r_l[il], buf.data(), dst_offset, r_row_size);
}
if (s_l[il]) {
ggml_context * ctx = ggml_init(params);
size_t s_row_size = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
ggml_tensor * src_v = ggml_view_1d(ctx, s_l[il], s_row_size, i_src * s_row_size);
ggml_tensor * dst_v = ggml_view_1d(ctx, s_l[il], s_row_size, i_dst * s_row_size);
ggml_backend_tensor_copy(src_v, dst_v);
ggml_free(ctx);
size_t src_offset = (size_t)i_src * s_row_size;
size_t dst_offset = (size_t)i_dst * s_row_size;
std::vector<uint8_t> buf(s_row_size);
ggml_backend_tensor_get(s_l[il], buf.data(), src_offset, s_row_size);
ggml_backend_tensor_set(s_l[il], buf.data(), dst_offset, s_row_size);
}
}
}
@@ -545,6 +557,10 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
const uint32_t n_seqs = ubatch.n_seqs;
fprintf(stderr, "[MTP-FINDSLOT] find_slot: n_seq_tokens=%d, n_seqs=%d, size=%d, used=%d, head=%d\n",
(int)n_seq_tokens, (int)n_seqs, (int)size, (int)used, (int)head);
fflush(stderr);
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (head > used + 2*n_seqs) {
@@ -688,6 +704,9 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
// overwritten by new tokens. This is required for speculative decoding rollback
// in recurrent/SSM models where tensor state cannot be partially rewound.
const int32_t cur_tail = seq_meta.tail;
fprintf(stderr, "[MTP-FINDSLOT] checkpoint branch: seq_id=%d, cur_tail=%d, next_empty=%d\n",
(int)seq_id, cur_tail, (int)next_empty_cell);
fflush(stderr);
if (cells[next_empty_cell].is_empty()) {
bool can_checkpoint = (get_cell_count(seq_id) < 8 && used < size * 0.9);
if (!can_checkpoint) {
@@ -712,11 +731,18 @@
}
if (can_checkpoint) {
auto & cp_cell = cells[next_empty_cell];
// Copy the current recurrent state as a checkpoint.
// This must happen before the graph overwrites the state
// with new token processing. The copy is safe here because
// find_slot runs BEFORE the compute graph for this ubatch.
copy_cell(cur_tail, next_empty_cell);
cp_cell.pos = cells[cur_tail].pos;
cp_cell.src = next_empty_cell; // independent copy, no further movement needed
cp_cell.src = next_empty_cell; // independent copy
cp_cell.seq_id.insert(seq_id);
used++;
fprintf(stderr, "[MTP-FINDSLOT] checkpoint at cell %d (copied from %d), pos=%d, used=%d\n",
(int)next_empty_cell, cur_tail, (int)cp_cell.pos, (int)used);
fflush(stderr);
// advance next_empty_cell for subsequent sequences in this batch
if (s + 1 < n_seqs) {
for (uint32_t j = 0; j < size; ++j) {


@@ -8112,6 +8112,13 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
// Use hybrid-iswa for hybrid models with SWA
// For MTP speculative decoding, we need extra recurrent state
// cells for checkpoint/restore: 1 active cell per sequence plus
// room for checkpoint cells (2 reserved per sequence here).
const uint32_t n_mtp = hparams.nextn_predict_layers;
const uint32_t rs_per_seq = 1 + (n_mtp > 0 ? 2 : 0); // active + checkpoint room
const uint32_t rs_size = std::max((uint32_t) 1, cparams.n_seq_max * rs_per_seq);
res = new llama_memory_hybrid_iswa(
/* model */ *this,
/* attn_type_k */ params.type_k,
@@ -8123,13 +8130,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* attn_n_pad */ 1,
/* recurrent_type_r */ GGML_TYPE_F32,
/* recurrent_type_s */ GGML_TYPE_F32,
/* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max),
/* recurrent_rs_size */ rs_size,
/* n_seq_max */ cparams.n_seq_max,
/* offload */ cparams.offload_kqv,
/* unified */ cparams.kv_unified,
/* filter_attn */ std::move(filter_attn),
/* filter_recr */ std::move(filter_recr));
} else {
// Same MTP checkpoint room for non-SWA path
const uint32_t n_mtp2 = hparams.nextn_predict_layers;
const uint32_t rs_per_seq2 = 1 + (n_mtp2 > 0 ? 2 : 0);
const uint32_t rs_size2 = std::max((uint32_t) 1, cparams.n_seq_max * rs_per_seq2);
res = new llama_memory_hybrid(
/* model */ *this,
/* attn_type_k */ params.type_k,
@@ -8141,7 +8153,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* attn_swa_type */ hparams.swa_type,
/* recurrent_type_k */ GGML_TYPE_F32,
/* recurrent_type_v */ GGML_TYPE_F32,
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
/* recurrent_kv_size */ rs_size2,
/* n_seq_max */ cparams.n_seq_max,
/* offload */ cparams.offload_kqv,
/* unified */ cparams.kv_unified,


@@ -2769,6 +2769,14 @@ private:
};
fprintf(stderr, "[MTP-DBG] llama_decode: n_tokens=%d, batch_start=%d\n", n_tokens, i);
// Print batch details for 2-token (speculative) batches
if (n_tokens == 2) {
for (int bi = 0; bi < n_tokens; bi++) {
fprintf(stderr, "[MTP-DBG] batch[%d]: token=%d, pos=%d, seq_id=%d, logits=%d\n",
bi, (int)batch.token[i+bi], (int)batch.pos[i+bi],
(int)batch.seq_id[i+bi][0], (int)batch.logits[i+bi]);
}
}
fflush(stderr);
const int ret = llama_decode(ctx, batch_view);
@@ -2890,7 +2898,11 @@ private:
slot.state = SLOT_STATE_GENERATING;
if (slot.can_speculate()) {
fprintf(stderr, "[MTP-DBG] speculative_begin for slot %d\n", slot.id);
fflush(stderr);
common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens());
fprintf(stderr, "[MTP-DBG] speculative_begin done\n");
fflush(stderr);
}
} else if (slot.state != SLOT_STATE_GENERATING) {
continue; // continue loop of slots