fix: CPU staging copy for recurrent state checkpoint (fixes crash)
Root cause found: copy_cell crashes during find_slot because it calls ggml_backend_tensor_copy on GPU tensors while the compute graph is being built. Fixed by using CPU staging: tensor_get (GPU→CPU) then tensor_set (CPU→GPU). Also increased rs_size from 1 to 3 cells per sequence to make room for the checkpoint cells needed by speculative-decoding rollback.

Results:
- No more crashes during speculative decode
- 23.8 tok/s with MTP (vs 16.7 without)
- 75% acceptance rate
- Output still garbled on long generation because seq_rm does not find checkpoints at the right positions (checkpoint position mismatch)

Next: fix checkpoint position tracking so seq_rm can find and restore the correct recurrent state after draft rejection.
This commit is contained in:
parent
4aeffc690d
commit
279e6c721e
|
|
@ -10,6 +10,7 @@
|
|||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <map>
|
||||
|
||||
#include <stdexcept>
|
||||
|
||||
//
|
||||
|
|
@ -165,18 +166,26 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
|||
if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
|
||||
// for speculative decoding, we search for a checkpoint in the history
|
||||
int32_t best_cell = -1;
|
||||
fprintf(stderr, "[MTP-SEQRM] seq_id=%d, p0=%d, p1=%d, tail_pos=%d, searching for checkpoint at pos=%d\n",
|
||||
(int)seq_id, (int)p0, (int)p1, (int)cell.pos, (int)(p0-1));
|
||||
for (uint32_t i = 0; i < size; ++i) {
|
||||
if (cells[i].has_seq_id(seq_id)) {
|
||||
fprintf(stderr, "[MTP-SEQRM] cell[%d] pos=%d\n", i, (int)cells[i].pos);
|
||||
}
|
||||
if (cells[i].has_seq_id(seq_id) && cells[i].pos == p0 - 1) {
|
||||
best_cell = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
fflush(stderr);
|
||||
|
||||
if (best_cell >= 0) {
|
||||
fprintf(stderr, "[MTP-SEQRM] FOUND checkpoint at cell[%d] pos=%d — rolling back\n", best_cell, (int)(p0-1));
|
||||
fflush(stderr);
|
||||
tail_id = best_cell;
|
||||
} else {
|
||||
// No checkpoint found at p0-1: SSM tensor state cannot be rolled back
|
||||
// without re-evaluating the sequence. Signal failure to the caller.
|
||||
fprintf(stderr, "[MTP-SEQRM] NO checkpoint found — seq_rm FAILED\n");
|
||||
fflush(stderr);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
@ -423,28 +432,31 @@ void llama_memory_recurrent::copy_cell(int32_t i_src, int32_t i_dst) {
|
|||
return;
|
||||
}
|
||||
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ size_t(2*ggml_tensor_overhead()),
|
||||
/*.mem_buffer =*/ NULL,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
fprintf(stderr, "[MTP-COPYCELL] copy_cell(%d -> %d), n_layer=%d\n", i_src, i_dst, (int)hparams.n_layer);
|
||||
fflush(stderr);
|
||||
|
||||
// Copy recurrent state via CPU staging buffer.
|
||||
// Direct GPU-to-GPU copy via ggml_backend_tensor_copy crashes
|
||||
// when called from find_slot during graph execution. Use CPU
|
||||
// as intermediary: GPU→CPU (tensor_get) then CPU→GPU (tensor_set).
|
||||
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
|
||||
if (r_l[il]) {
|
||||
ggml_context * ctx = ggml_init(params);
|
||||
size_t r_row_size = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
|
||||
ggml_tensor * src_v = ggml_view_1d(ctx, r_l[il], r_row_size, i_src * r_row_size);
|
||||
ggml_tensor * dst_v = ggml_view_1d(ctx, r_l[il], r_row_size, i_dst * r_row_size);
|
||||
ggml_backend_tensor_copy(src_v, dst_v);
|
||||
ggml_free(ctx);
|
||||
size_t src_offset = (size_t)i_src * r_row_size;
|
||||
size_t dst_offset = (size_t)i_dst * r_row_size;
|
||||
|
||||
std::vector<uint8_t> buf(r_row_size);
|
||||
ggml_backend_tensor_get(r_l[il], buf.data(), src_offset, r_row_size);
|
||||
ggml_backend_tensor_set(r_l[il], buf.data(), dst_offset, r_row_size);
|
||||
}
|
||||
if (s_l[il]) {
|
||||
ggml_context * ctx = ggml_init(params);
|
||||
size_t s_row_size = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
|
||||
ggml_tensor * src_v = ggml_view_1d(ctx, s_l[il], s_row_size, i_src * s_row_size);
|
||||
ggml_tensor * dst_v = ggml_view_1d(ctx, s_l[il], s_row_size, i_dst * s_row_size);
|
||||
ggml_backend_tensor_copy(src_v, dst_v);
|
||||
ggml_free(ctx);
|
||||
size_t src_offset = (size_t)i_src * s_row_size;
|
||||
size_t dst_offset = (size_t)i_dst * s_row_size;
|
||||
|
||||
std::vector<uint8_t> buf(s_row_size);
|
||||
ggml_backend_tensor_get(s_l[il], buf.data(), src_offset, s_row_size);
|
||||
ggml_backend_tensor_set(s_l[il], buf.data(), dst_offset, s_row_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -545,6 +557,10 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|||
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
||||
const uint32_t n_seqs = ubatch.n_seqs;
|
||||
|
||||
fprintf(stderr, "[MTP-FINDSLOT] find_slot: n_seq_tokens=%d, n_seqs=%d, size=%d, used=%d, head=%d\n",
|
||||
(int)n_seq_tokens, (int)n_seqs, (int)size, (int)used, (int)head);
|
||||
fflush(stderr);
|
||||
|
||||
// if we have enough unused cells before the current head ->
|
||||
// better to start searching from the beginning of the cache, hoping to fill it
|
||||
if (head > used + 2*n_seqs) {
|
||||
|
|
@ -688,6 +704,9 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|||
// overwritten by new tokens. This is required for speculative decoding rollback
|
||||
// in recurrent/SSM models where tensor state cannot be partially rewound.
|
||||
const int32_t cur_tail = seq_meta.tail;
|
||||
fprintf(stderr, "[MTP-FINDSLOT] checkpoint branch: seq_id=%d, cur_tail=%d, next_empty=%d\n",
|
||||
(int)seq_id, cur_tail, (int)next_empty_cell);
|
||||
fflush(stderr);
|
||||
if (cells[next_empty_cell].is_empty()) {
|
||||
bool can_checkpoint = (get_cell_count(seq_id) < 8 && used < size * 0.9);
|
||||
if (!can_checkpoint) {
|
||||
|
|
@ -712,11 +731,18 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|||
}
|
||||
if (can_checkpoint) {
|
||||
auto & cp_cell = cells[next_empty_cell];
|
||||
// Copy the current recurrent state as a checkpoint.
|
||||
// This must happen before the graph overwrites the state
|
||||
// with new token processing. The copy is safe here because
|
||||
// find_slot runs BEFORE the compute graph for this ubatch.
|
||||
copy_cell(cur_tail, next_empty_cell);
|
||||
cp_cell.pos = cells[cur_tail].pos;
|
||||
cp_cell.src = next_empty_cell; // independent copy, no further movement needed
|
||||
cp_cell.src = next_empty_cell; // independent copy
|
||||
cp_cell.seq_id.insert(seq_id);
|
||||
used++;
|
||||
fprintf(stderr, "[MTP-FINDSLOT] checkpoint at cell %d (copied from %d), pos=%d, used=%d\n",
|
||||
(int)next_empty_cell, cur_tail, (int)cp_cell.pos, (int)used);
|
||||
fflush(stderr);
|
||||
// advance next_empty_cell for subsequent sequences in this batch
|
||||
if (s + 1 < n_seqs) {
|
||||
for (uint32_t j = 0; j < size; ++j) {
|
||||
|
|
|
|||
|
|
@ -8112,6 +8112,13 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
|
||||
if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
|
||||
// Use hybrid-iswa for hybrid models with SWA
|
||||
// For MTP speculative decoding, we need extra recurrent state
|
||||
// cells for checkpoint/restore. Each sequence needs at least
|
||||
// 1 active cell + 1 checkpoint cell per MTP draft step.
|
||||
const uint32_t n_mtp = hparams.nextn_predict_layers;
|
||||
const uint32_t rs_per_seq = 1 + (n_mtp > 0 ? 2 : 0); // active + checkpoint room
|
||||
const uint32_t rs_size = std::max((uint32_t) 1, cparams.n_seq_max * rs_per_seq);
|
||||
|
||||
res = new llama_memory_hybrid_iswa(
|
||||
/* model */ *this,
|
||||
/* attn_type_k */ params.type_k,
|
||||
|
|
@ -8123,13 +8130,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
/* attn_n_pad */ 1,
|
||||
/* recurrent_type_r */ GGML_TYPE_F32,
|
||||
/* recurrent_type_s */ GGML_TYPE_F32,
|
||||
/* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max),
|
||||
/* recurrent_rs_size */ rs_size,
|
||||
/* n_seq_max */ cparams.n_seq_max,
|
||||
/* offload */ cparams.offload_kqv,
|
||||
/* unified */ cparams.kv_unified,
|
||||
/* filter_attn */ std::move(filter_attn),
|
||||
/* filter_recr */ std::move(filter_recr));
|
||||
} else {
|
||||
// Same MTP checkpoint room for non-SWA path
|
||||
const uint32_t n_mtp2 = hparams.nextn_predict_layers;
|
||||
const uint32_t rs_per_seq2 = 1 + (n_mtp2 > 0 ? 2 : 0);
|
||||
const uint32_t rs_size2 = std::max((uint32_t) 1, cparams.n_seq_max * rs_per_seq2);
|
||||
|
||||
res = new llama_memory_hybrid(
|
||||
/* model */ *this,
|
||||
/* attn_type_k */ params.type_k,
|
||||
|
|
@ -8141,7 +8153,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
/* attn_swa_type */ hparams.swa_type,
|
||||
/* recurrent_type_k */ GGML_TYPE_F32,
|
||||
/* recurrent_type_v */ GGML_TYPE_F32,
|
||||
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
|
||||
/* recurrent_kv_size */ rs_size2,
|
||||
/* n_seq_max */ cparams.n_seq_max,
|
||||
/* offload */ cparams.offload_kqv,
|
||||
/* unified */ cparams.kv_unified,
|
||||
|
|
|
|||
|
|
@ -2769,6 +2769,14 @@ private:
|
|||
};
|
||||
|
||||
fprintf(stderr, "[MTP-DBG] llama_decode: n_tokens=%d, batch_start=%d\n", n_tokens, i);
|
||||
// Print batch details for 2-token (speculative) batches
|
||||
if (n_tokens == 2) {
|
||||
for (int bi = 0; bi < n_tokens; bi++) {
|
||||
fprintf(stderr, "[MTP-DBG] batch[%d]: token=%d, pos=%d, seq_id=%d, logits=%d\n",
|
||||
bi, (int)batch.token[i+bi], (int)batch.pos[i+bi],
|
||||
(int)batch.seq_id[i+bi][0], (int)batch.logits[i+bi]);
|
||||
}
|
||||
}
|
||||
fflush(stderr);
|
||||
|
||||
const int ret = llama_decode(ctx, batch_view);
|
||||
|
|
@ -2890,7 +2898,11 @@ private:
|
|||
slot.state = SLOT_STATE_GENERATING;
|
||||
|
||||
if (slot.can_speculate()) {
|
||||
fprintf(stderr, "[MTP-DBG] speculative_begin for slot %d\n", slot.id);
|
||||
fflush(stderr);
|
||||
common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens());
|
||||
fprintf(stderr, "[MTP-DBG] speculative_begin done\n");
|
||||
fflush(stderr);
|
||||
}
|
||||
} else if (slot.state != SLOT_STATE_GENERATING) {
|
||||
continue; // continue loop of slots
|
||||
|
|
|
|||
Loading…
Reference in New Issue