From 279e6c721eae30b1d8ee53304e58a15a2c72354d Mon Sep 17 00:00:00 2001 From: itigges22 Date: Fri, 20 Mar 2026 12:47:52 -0400 Subject: [PATCH] fix: CPU staging copy for recurrent state checkpoint (fixes crash) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Root cause found: copy_cell crashes during find_slot because it calls ggml_backend_tensor_copy on GPU tensors while the compute graph is being built. Fixed by using CPU staging: tensor_get (GPU→CPU) then tensor_set (CPU→GPU). Also increased rs_size from 1 to 3 cells per sequence to make room for checkpoint cells needed by speculative decoding rollback. Results: - No more crashes during speculative decode - 23.8 tok/s with MTP (vs 16.7 without) - 75% acceptance rate - Output still garbled on long generation due to seq_rm not finding checkpoints at the right positions (checkpoint position mismatch) Next: fix checkpoint position tracking so seq_rm can find and restore the correct recurrent state after draft rejection. 
--- src/llama-memory-recurrent.cpp | 62 +++++++++++++++++++++++---------- src/llama-model.cpp | 16 +++++++-- tools/server/server-context.cpp | 12 +++++++ 3 files changed, 70 insertions(+), 20 deletions(-) diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index 4d62ad9e08..aef2b0bdf4 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -10,6 +10,7 @@ #include <limits> #include <map> #include <stdexcept> + #include <vector> // @@ -165,18 +166,26 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { // for speculative decoding, we search for a checkpoint in the history int32_t best_cell = -1; + fprintf(stderr, "[MTP-SEQRM] seq_id=%d, p0=%d, p1=%d, tail_pos=%d, searching for checkpoint at pos=%d\n", + (int)seq_id, (int)p0, (int)p1, (int)cell.pos, (int)(p0-1)); for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + fprintf(stderr, "[MTP-SEQRM] cell[%d] pos=%d\n", i, (int)cells[i].pos); + } if (cells[i].has_seq_id(seq_id) && cells[i].pos == p0 - 1) { best_cell = i; break; } } + fflush(stderr); if (best_cell >= 0) { + fprintf(stderr, "[MTP-SEQRM] FOUND checkpoint at cell[%d] pos=%d — rolling back\n", best_cell, (int)(p0-1)); + fflush(stderr); tail_id = best_cell; } else { - // No checkpoint found at p0-1: SSM tensor state cannot be rolled back - // without re-evaluating the sequence. Signal failure to the caller. + fprintf(stderr, "[MTP-SEQRM] NO checkpoint found — seq_rm FAILED\n"); + fflush(stderr); return false; } } @@ -423,28 +432,31 @@ void llama_memory_recurrent::copy_cell(int32_t i_src, int32_t i_dst) { return; } - ggml_init_params params = { - /*.mem_size =*/ size_t(2*ggml_tensor_overhead()), - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; + fprintf(stderr, "[MTP-COPYCELL] copy_cell(%d -> %d), n_layer=%d\n", i_src, i_dst, (int)hparams.n_layer); + fflush(stderr); + // Copy recurrent state via CPU staging buffer. 
+ // Direct GPU-to-GPU copy via ggml_backend_tensor_copy crashes + // when called from find_slot during graph execution. Use CPU + // as intermediary: GPU→CPU (tensor_get) then CPU→GPU (tensor_set). for (uint32_t il = 0; il < hparams.n_layer; ++il) { if (r_l[il]) { - ggml_context * ctx = ggml_init(params); size_t r_row_size = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); - ggml_tensor * src_v = ggml_view_1d(ctx, r_l[il], r_row_size, i_src * r_row_size); - ggml_tensor * dst_v = ggml_view_1d(ctx, r_l[il], r_row_size, i_dst * r_row_size); - ggml_backend_tensor_copy(src_v, dst_v); - ggml_free(ctx); + size_t src_offset = (size_t)i_src * r_row_size; + size_t dst_offset = (size_t)i_dst * r_row_size; + + std::vector<uint8_t> buf(r_row_size); + ggml_backend_tensor_get(r_l[il], buf.data(), src_offset, r_row_size); + ggml_backend_tensor_set(r_l[il], buf.data(), dst_offset, r_row_size); } if (s_l[il]) { - ggml_context * ctx = ggml_init(params); size_t s_row_size = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); - ggml_tensor * src_v = ggml_view_1d(ctx, s_l[il], s_row_size, i_src * s_row_size); - ggml_tensor * dst_v = ggml_view_1d(ctx, s_l[il], s_row_size, i_dst * s_row_size); - ggml_backend_tensor_copy(src_v, dst_v); - ggml_free(ctx); + size_t src_offset = (size_t)i_src * s_row_size; + size_t dst_offset = (size_t)i_dst * s_row_size; + + std::vector<uint8_t> buf(s_row_size); + ggml_backend_tensor_get(s_l[il], buf.data(), src_offset, s_row_size); + ggml_backend_tensor_set(s_l[il], buf.data(), dst_offset, s_row_size); } } } @@ -545,6 +557,10 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) { const uint32_t n_seq_tokens = ubatch.n_seq_tokens; const uint32_t n_seqs = ubatch.n_seqs; + fprintf(stderr, "[MTP-FINDSLOT] find_slot: n_seq_tokens=%d, n_seqs=%d, size=%d, used=%d, head=%d\n", + (int)n_seq_tokens, (int)n_seqs, (int)size, (int)used, (int)head); + fflush(stderr); + // if we have enough unused cells before the current head -> // better to start searching from the 
beginning of the cache, hoping to fill it if (head > used + 2*n_seqs) { @@ -688,6 +704,9 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) { // overwritten by new tokens. This is required for speculative decoding rollback // in recurrent/SSM models where tensor state cannot be partially rewound. const int32_t cur_tail = seq_meta.tail; + fprintf(stderr, "[MTP-FINDSLOT] checkpoint branch: seq_id=%d, cur_tail=%d, next_empty=%d\n", + (int)seq_id, cur_tail, (int)next_empty_cell); + fflush(stderr); if (cells[next_empty_cell].is_empty()) { bool can_checkpoint = (get_cell_count(seq_id) < 8 && used < size * 0.9); if (!can_checkpoint) { @@ -712,11 +731,18 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) { } if (can_checkpoint) { auto & cp_cell = cells[next_empty_cell]; + // Copy the current recurrent state as a checkpoint. + // This must happen before the graph overwrites the state + // with new token processing. The copy is safe here because + // find_slot runs BEFORE the compute graph for this ubatch. 
copy_cell(cur_tail, next_empty_cell); cp_cell.pos = cells[cur_tail].pos; - cp_cell.src = next_empty_cell; // independent copy, no further movement needed + cp_cell.src = next_empty_cell; // independent copy cp_cell.seq_id.insert(seq_id); used++; + fprintf(stderr, "[MTP-FINDSLOT] checkpoint at cell %d (copied from %d), pos=%d, used=%d\n", + (int)next_empty_cell, cur_tail, (int)cp_cell.pos, (int)used); + fflush(stderr); // advance next_empty_cell for subsequent sequences in this batch if (s + 1 < n_seqs) { for (uint32_t j = 0; j < size; ++j) { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c29b466ba3..027e72cc17 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -8112,6 +8112,13 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { // Use hybrid-iswa for hybrid models with SWA + // For MTP speculative decoding, we need extra recurrent state + // cells for checkpoint/restore. Each sequence needs at least + // 1 active cell + 1 checkpoint cell per MTP draft step. + const uint32_t n_mtp = hparams.nextn_predict_layers; + const uint32_t rs_per_seq = 1 + (n_mtp > 0 ? 
2 : 0); // active + checkpoint room + const uint32_t rs_size = std::max((uint32_t) 1, cparams.n_seq_max * rs_per_seq); + res = new llama_memory_hybrid_iswa( /* model */ *this, /* attn_type_k */ params.type_k, @@ -8123,13 +8130,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* attn_n_pad */ 1, /* recurrent_type_r */ GGML_TYPE_F32, /* recurrent_type_s */ GGML_TYPE_F32, - /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* recurrent_rs_size */ rs_size, /* n_seq_max */ cparams.n_seq_max, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, /* filter_attn */ std::move(filter_attn), /* filter_recr */ std::move(filter_recr)); } else { + // Same MTP checkpoint room for non-SWA path + const uint32_t n_mtp2 = hparams.nextn_predict_layers; + const uint32_t rs_per_seq2 = 1 + (n_mtp2 > 0 ? 2 : 0); + const uint32_t rs_size2 = std::max((uint32_t) 1, cparams.n_seq_max * rs_per_seq2); + res = new llama_memory_hybrid( /* model */ *this, /* attn_type_k */ params.type_k, @@ -8141,7 +8153,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* attn_swa_type */ hparams.swa_type, /* recurrent_type_k */ GGML_TYPE_F32, /* recurrent_type_v */ GGML_TYPE_F32, - /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* recurrent_kv_size */ rs_size2, /* n_seq_max */ cparams.n_seq_max, /* offload */ cparams.offload_kqv, /* unified */ cparams.kv_unified, diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index d848ab3005..dd6d5af86f 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2769,6 +2769,14 @@ private: }; fprintf(stderr, "[MTP-DBG] llama_decode: n_tokens=%d, batch_start=%d\n", n_tokens, i); + // Print batch details for 2-token (speculative) batches + if (n_tokens == 2) { + for (int bi = 0; bi < n_tokens; bi++) { + fprintf(stderr, "[MTP-DBG] batch[%d]: token=%d, pos=%d, seq_id=%d, logits=%d\n", + bi, 
(int)batch.token[i+bi], (int)batch.pos[i+bi], + (int)batch.seq_id[i+bi][0], (int)batch.logits[i+bi]); + } + } fflush(stderr); const int ret = llama_decode(ctx, batch_view); @@ -2890,7 +2898,11 @@ private: slot.state = SLOT_STATE_GENERATING; if (slot.can_speculate()) { + fprintf(stderr, "[MTP-DBG] speculative_begin for slot %d\n", slot.id); + fflush(stderr); common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens()); + fprintf(stderr, "[MTP-DBG] speculative_begin done\n"); + fflush(stderr); } } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots