diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index 4cbbca0925..3186242d60 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -1561,9 +1561,11 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama

         const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];

+        slot_info sinfo;
+
         bool res = true;
-        res = res && state_read_meta(io, strm, cell_count, seq_id);
-        res = res && state_read_data(io, strm, cell_count);
+        res = res && state_read_meta(io, strm, cell_count, sinfo, seq_id);
+        res = res && state_read_data(io, strm, cell_count, sinfo);

         if (!res) {
             if (seq_id == -1) {
@@ -1702,7 +1704,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
     }
 }

-bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
+bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id) {
     auto & cells = v_cells[strm];
     auto & head  = v_heads[strm];

@@ -1739,7 +1741,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
             ubatch.seq_id[i]   = &dest_seq_id;
         }

-        const auto sinfo = find_slot(ubatch, true);
+        sinfo = find_slot(ubatch, false);
         if (sinfo.empty()) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
@@ -1749,20 +1751,16 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
         // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
         apply_ubatch(sinfo, ubatch);

-        const auto head_cur = sinfo.head();
+        LLAMA_LOG_DEBUG("%s: cell_count = %d, dest_seq_id = %d\n", __func__, cell_count, dest_seq_id);

-        // keep the head at the old position because we will read the KV data into it in state_read_data()
-        head = head_cur;
-
-        LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
-
-        // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
-        // Assume that this is one contiguous block of cells
-        GGML_ASSERT(head_cur + cell_count <= cells.size());
-        GGML_ASSERT(cells.pos_get(head_cur) == ubatch.pos[0]);
-        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
-        GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
-        GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
+        // DEBUG CHECK: verify that all cells were allocated and have correct seq_id and pos values
+        GGML_ASSERT(sinfo.n_stream() == 1);
+        GGML_ASSERT(sinfo.idxs[0].size() == cell_count);
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            const uint32_t idx = sinfo.idxs[0][i];
+            GGML_ASSERT(cells.pos_get(idx) == ubatch.pos[i]);
+            GGML_ASSERT(cells.seq_has(idx, dest_seq_id));
+        }
     } else {
         // whole KV cache restore

@@ -1795,15 +1793,24 @@
             }
         }

+        // Create contiguous slot_info for whole cache restore
+        sinfo.s0 = strm;
+        sinfo.s1 = strm;
+        sinfo.resize(1);
+        sinfo.strm[0] = strm;
+        sinfo.idxs[0].resize(cell_count);
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            sinfo.idxs[0][i] = i;
+        }
+
         head = 0;
     }

     return true;
 }

-bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
+bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo) {
     auto & cells = v_cells[strm];
-    auto & head  = v_heads[strm];

     uint32_t v_trans;
     uint32_t n_layer;
@@ -1853,8 +1860,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
        }

        if (cell_count) {
-            // Read and set the keys for the whole cell range
-            ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+            if (sinfo.is_contiguous()) {
+                // Fast path: contiguous cells, single memcpy
+                ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row);
+            } else {
+                // Slow path: scatter to non-contiguous positions
+                const void * src = io.read(cell_count * k_size_row);
+                for (uint32_t i = 0; i < cell_count; ++i) {
+                    const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
+                    ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row);
+                }
+            }
        }
    }

@@ -1885,8 +1901,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
            }

            if (cell_count) {
-                // Read and set the values for the whole cell range
-                ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+                if (sinfo.is_contiguous()) {
+                    // Fast path: contiguous cells, single memcpy
+                    ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row);
+                } else {
+                    // Slow path: scatter to non-contiguous positions
+                    const void * src = io.read(cell_count * v_size_row);
+                    for (uint32_t i = 0; i < cell_count; ++i) {
+                        const size_t dst_offset = sinfo.idxs[0][i] * v_size_row;
+                        ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row);
+                    }
+                }
            }
        }
    } else {
@@ -1925,10 +1950,22 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
            }

            if (cell_count) {
-                // For each row in the transposed matrix, read the values for the whole cell range
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    const size_t dst_offset = (head + j * cells.size()) * v_size_el;
-                    ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                if (sinfo.is_contiguous()) {
+                    // Fast path: contiguous cells
+                    const uint32_t h = sinfo.head();
+                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                        const size_t dst_offset = (h + j * cells.size()) * v_size_el;
+                        ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                    }
+                } else {
+                    // Slow path: scatter to non-contiguous positions
+                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                        const void * src = io.read(cell_count * v_size_el);
+                        for (uint32_t i = 0; i < cell_count; ++i) {
+                            const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el;
+                            ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el);
+                        }
+                    }
                }
            }
        }
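Note (annotation, not part of the patch): the slow paths above are row-wise scatters. The serialized K/V rows arrive contiguously from io.read(), but the destination cells chosen by find_slot(ubatch, false) may be spread across the cache, so each row is written to its own cell index instead of one block write. The following self-contained sketch shows the same indexing with invented names and sizes, using a plain byte vector where the patch uses ggml_backend_tensor_set:

// Standalone illustration of the row-wise scatter used by the slow path.
// A std::vector<uint8_t> stands in for the backend tensor; scatter_rows()
// plays the role of the per-row ggml_backend_tensor_set calls.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// copy row i of the contiguous source buffer into cell idxs[i] of the destination
static void scatter_rows(std::vector<uint8_t> & dst, const std::vector<uint8_t> & src,
                         const std::vector<uint32_t> & idxs, size_t row_size) {
    for (size_t i = 0; i < idxs.size(); ++i) {
        std::memcpy(dst.data() + idxs[i] * row_size, src.data() + i * row_size, row_size);
    }
}

int main() {
    const size_t row_size = 4;                          // bytes per cell row (made up)
    const size_t n_cells  = 8;                          // total cells in the "cache"
    const std::vector<uint32_t> idxs = { 1, 2, 5, 6 };  // fragmented destination cells

    std::vector<uint8_t> cache(n_cells * row_size, 0);
    std::vector<uint8_t> blob(idxs.size() * row_size);
    for (size_t i = 0; i < blob.size(); ++i) {
        blob[i] = (uint8_t) (i + 1);                    // pretend this came from io.read()
    }

    scatter_rows(cache, blob, idxs, row_size);

    for (size_t c = 0; c < n_cells; ++c) {
        printf("cell %zu: %3d %3d %3d %3d\n", c,
               cache[c*row_size + 0], cache[c*row_size + 1],
               cache[c*row_size + 2], cache[c*row_size + 3]);
    }
    return 0;
}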
diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h
index bf7821c07c..1868f11857 100644
--- a/src/llama-kv-cache.h
+++ b/src/llama-kv-cache.h
@@ -72,6 +72,23 @@ public:
         void clear() {
             idxs.clear();
         }
+
+        // check if indices are contiguous starting from head()
+        bool is_contiguous() const {
+            if (idxs.empty() || idxs[0].empty()) {
+                return true;
+            }
+            if (idxs.size() > 1) {
+                return false;
+            }
+            const uint32_t h = idxs[0][0];
+            for (size_t i = 0; i < idxs[0].size(); ++i) {
+                if (idxs[0][i] != h + i) {
+                    return false;
+                }
+            }
+            return true;
+        }
     };

     using slot_info_vec_t = std::vector<slot_info>;
@@ -264,8 +281,8 @@ private:
     void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
     void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;

-    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
-    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
+    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo);
 };

 class llama_kv_cache_context : public llama_memory_context_i {
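Note (annotation, not part of the patch): is_contiguous() reports true only for an empty slot or a single-stream slot whose indices increase by exactly one from idxs[0][0]; anything with more than one stream or a hole takes the scatter path. A standalone replica of the predicate with a few example index sets, as a mental model only (the real member operates on slot_info::idxs):

// Standalone replica of the is_contiguous() predicate with example inputs.
#include <cassert>
#include <cstdint>
#include <vector>

static bool is_contiguous(const std::vector<std::vector<uint32_t>> & idxs) {
    if (idxs.empty() || idxs[0].empty()) {
        return true;                  // an empty slot counts as contiguous
    }
    if (idxs.size() > 1) {
        return false;                 // more than one stream -> not contiguous
    }
    const uint32_t h = idxs[0][0];
    for (size_t i = 0; i < idxs[0].size(); ++i) {
        if (idxs[0][i] != h + i) {
            return false;             // hole or reordering in the index list
        }
    }
    return true;
}

int main() {
    assert( is_contiguous({}));                 // empty slot
    assert( is_contiguous({{5, 6, 7}}));        // one stream, consecutive cells
    assert(!is_contiguous({{1, 2, 5, 6}}));     // hole between 2 and 5
    assert(!is_contiguous({{0, 1}, {2, 3}}));   // two streams
    return 0;
}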
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 9ba559c8df..c3d9f9c324 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -222,6 +222,14 @@ llama_build_and_test(test-backend-ops.cpp)
 llama_build_and_test(test-model-load-cancel.cpp LABEL "model")
 llama_build_and_test(test-autorelease.cpp LABEL "model")

+# Test for state restore with fragmented KV cache
+# Requires a model, uses same args pattern as test-thread-safety
+if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
+    llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf)
+else()
+    llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -hf ggml-org/models -hff tinyllamas/stories15M-be.Q4_0.gguf)
+endif()
+
 if (NOT GGML_BACKEND_DL)
     # these tests use the backends directly and cannot be built with dynamic loading
     llama_build_and_test(test-barrier.cpp)
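Note (annotation, not part of the patch): with a configured build directory and network access for the model download, the new test should be selectable on its own with something like "ctest -R state-restore-fragmented", assuming the usual convention that llama_build_and_test registers the test under its source file's base name.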
diff --git a/tests/test-state-restore-fragmented.cpp b/tests/test-state-restore-fragmented.cpp
new file mode 100644
index 0000000000..481b39d04c
--- /dev/null
+++ b/tests/test-state-restore-fragmented.cpp
@@ -0,0 +1,122 @@
+// Test for state restore with fragmented KV cache
+// This tests the fix for: https://github.com/ggml-org/llama.cpp/issues/17527
+// The issue was that state restore required contiguous KV cache slots,
+// which fails when the cache is fragmented.
+//
+// The fix changes find_slot(ubatch, true) to find_slot(ubatch, false)
+// in state_read_meta(), allowing non-contiguous slot allocation.
+
+#include "arg.h"
+#include "common.h"
+#include "llama.h"
+
+#include <cstdio>
+#include <string>
+#include <vector>
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    params.sampling.seed = 1234;
+    params.kv_unified    = true;
+    params.n_parallel    = 3;
+    params.n_ctx         = 256;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
+        return 1;
+    }
+
+    common_init();
+
+    // init
+    common_init_result_ptr llama_init = common_init_from_params(params);
+
+    llama_model   * model = llama_init->model();
+    llama_context * ctx   = llama_init->context();
+
+    if (model == nullptr || ctx == nullptr) {
+        fprintf(stderr, "%s : failed to init\n", __func__);
+        return 1;
+    }
+
+    GGML_UNUSED(model);
+
+    // create a dummy prompt (70 tokens per sequence)
+    std::vector<llama_token> tokens(70, 1);
+
+    // interleave the 3 sequences:
+    // 012012012...
+    llama_batch batch = llama_batch_init(params.n_parallel*tokens.size(), 0, 1);
+    for (size_t i = 0; i < tokens.size(); i++) {
+        for (int s = 0; s < params.n_parallel; ++s) {
+            common_batch_add(batch, tokens[i], i, {s}, false);
+        }
+    }
+    batch.logits[batch.n_tokens - 1] = true;
+
+    if (llama_decode(ctx, batch)) {
+        fprintf(stderr, "%s : failed to decode prompt batch\n", __func__);
+        return 1;
+    }
+
+    fprintf(stderr, "%s : processed prompt on seq 0, 1, 2 (%zu tokens each)\n", __func__, tokens.size());
+
+    // Save state of seq 1
+    std::vector<uint8_t> seq_state(llama_state_seq_get_size(ctx, 1));
+    const size_t ncopy = llama_state_seq_get_data(ctx, seq_state.data(), seq_state.size(), 1);
+    if (ncopy != seq_state.size()) {
+        fprintf(stderr, "%s : failed to save seq 1 state\n", __func__);
+        return 1;
+    }
+    fprintf(stderr, "%s : saved seq 1 state, %zu bytes\n", __func__, ncopy);
+
+    // clear seq 1 to create a "hole" in the KV cache (fragmentation)
+    // 0.20.20.20.2....
+    llama_memory_t mem = llama_get_memory(ctx);
+    llama_memory_seq_rm(mem, 1, -1, -1);
+    fprintf(stderr, "%s : cleared seq 1 to create fragmentation\n", __func__);
+
+    // Now the cache has holes where seq 1 was
+    // This creates fragmentation - there's no contiguous block large enough
+    // for the seq 1 state if we only look for contiguous slots
+
+    // Restore seq 1 state into seq 1 (should work with non-contiguous allocation)
+    // We use seq 1 since it's a valid sequence ID (0 to n_parallel-1)
+    // Before the fix, this would fail with "failed to find available cells in kv cache"
+    const size_t nset = llama_state_seq_set_data(ctx, seq_state.data(), seq_state.size(), 1);
+    if (nset != seq_state.size()) {
+        fprintf(stderr, "%s : FAILED to restore seq state into fragmented cache (got %zu, expected %zu)\n",
+                __func__, nset, seq_state.size());
+        fprintf(stderr, "%s : This is the bug - state restore fails with fragmented KV cache\n", __func__);
+        llama_batch_free(batch);
+        return 1;
+    }
+    fprintf(stderr, "%s : restored state into seq 1, %zu bytes\n", __func__, nset);
+
+    // Verify we can decode with the restored state
+    // Generate one token to verify the restored state is usable
+    auto sparams = llama_sampler_chain_default_params();
+    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+    llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed));
+
+    auto next_token     = llama_sampler_sample(smpl, ctx, -1);
+    auto next_token_str = common_token_to_piece(ctx, next_token);
+
+    common_batch_clear(batch);
+    common_batch_add(batch, next_token, (int)tokens.size(), {1}, true);
+
+    if (llama_decode(ctx, batch)) {
+        fprintf(stderr, "%s : failed to decode with restored state\n", __func__);
+        llama_sampler_free(smpl);
+        llama_batch_free(batch);
+        return 1;
+    }
+
+    fprintf(stderr, "%s : successfully decoded with restored state, generated: '%s'\n", __func__, next_token_str.c_str());
+    fprintf(stderr, "%s : SUCCESS - state restore works with fragmented KV cache\n", __func__);
+
+    llama_sampler_free(smpl);
+    llama_batch_free(batch);
+
+    return 0;
+}
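Note (annotation, not part of the patch): the transposed-V branch is the least obvious piece of addressing in the change above. With the V cache stored transposed, element j of cell c sits at flat offset (c + j * n_cells) * v_size_el, so the slow path has to scatter element by element rather than row by row. A toy illustration of that layout, with invented sizes and plain integers in place of tensor data:

// Toy illustration of the transposed V layout used by the slow path:
// element j of cell c is stored at flat index (c + j * n_cells).
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const uint32_t n_cells      = 8;               // cache size (made up)
    const uint32_t n_embd_v_gqa = 3;               // elements per cell (made up)
    const std::vector<uint32_t> idxs = {1, 2, 5};  // fragmented destination cells

    // destination "tensor": n_embd_v_gqa rows of n_cells elements (transposed layout)
    std::vector<int> v(n_cells * n_embd_v_gqa, 0);

    // scatter consecutive values into cells idxs[i] of each row j,
    // using the same offset formula as the patch
    int next = 1;
    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
        for (uint32_t i = 0; i < (uint32_t) idxs.size(); ++i) {
            v[idxs[i] + j * n_cells] = next++;
        }
    }

    // print the resulting layout: non-zero entries appear only in columns 1, 2 and 5
    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
        for (uint32_t c = 0; c < n_cells; ++c) {
            printf("%2d ", v[c + j * n_cells]);
        }
        printf("\n");
    }
    return 0;
}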