diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index f0038036dc..6e8413f493 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -163,7 +163,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos const auto & cell = cells[tail_id]; // partial intersection is invalid if it includes the final pos if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { - //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n"); + //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1); return false; } // invalidate tails which will be cleared diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 88b6e77d82..ff3c6d3c2b 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -231,19 +231,77 @@ server_tokens::server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : server_tokens::server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) { } -llama_pos server_tokens::pos_next() const { +llama_pos server_tokens::pos_next(int64_t n_tokens) const { if (!has_mtmd) { - return tokens.size(); + if (n_tokens < 0) { + return tokens.size(); + } + + return n_tokens; } - llama_pos res = tokens.size(); + if (n_tokens < 0) { + llama_pos res = tokens.size(); - for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) { - const auto & chunk = it->second; - res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get()); + for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) { + const auto & chunk = it->second; + res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get()); + } + + return res; } - return res; + int64_t idx = 0; + llama_pos pos = 0; + + GGML_ASSERT(n_tokens <= (int64_t)tokens.size()); + + while (idx < n_tokens) { + const auto media_it = map_idx_to_media.find(idx); + if (media_it != map_idx_to_media.end()) { + const auto & chunk = media_it->second; + const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get()); + const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get()); + + pos += n_pos; + idx += n_tok; + } else { + pos++; + idx++; + } + } + + return pos; +} + +size_t server_tokens::size_up_to_pos(llama_pos max_pos) const { + if (!has_mtmd) { + return std::min((size_t)(max_pos + 1), tokens.size()); + } + + size_t idx = 0; + llama_pos pos = 0; + + while (idx < tokens.size()) { + const auto media_it = map_idx_to_media.find(idx); + if (media_it != map_idx_to_media.end()) { + const auto & chunk = media_it->second; + const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get()); + const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get()); + + pos += n_pos; + idx += n_tok; + } else { + pos++; + idx++; + } + + if (pos > max_pos) { + break; + } + } + + return idx; } std::string server_tokens::str() const { diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 2629a6bee9..4fb9e488df 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -167,7 +167,12 @@ public: // for debugging std::string str() const; - llama_pos pos_next() const; + // the next position after n_tokens. if n_tokens < 0, return the next position after all tokens. + llama_pos pos_next(int64_t n_tokens = -1) const; + + // number of tokens with position <= max_pos + size_t size_up_to_pos(llama_pos max_pos) const; + const mtmd::input_chunk_ptr & find_chunk(size_t idx) const; void push_back(llama_token tok); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 0f2f3a45aa..67c3988bd0 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1442,7 +1442,7 @@ private: res->id = slot.task->id; res->id_slot = slot.id; - res->index = slot.task->index; + res->index = slot.task->index; // keep copy of last generated text for debugging purposes if (slots_debug) { @@ -2282,15 +2282,15 @@ private: n_past = 0; } + llama_pos pos_next = slot.prompt.tokens.pos_next(n_past); + // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1 const auto n_swa = std::max(1, llama_model_n_swa(model)); // the largest pos_min required for a checkpoint to be useful - const auto pos_min_thold = std::max(0, n_past - n_swa); + const auto pos_min_thold = std::max(0, pos_next - n_swa); - // note: disallow with mtmd contexts for now - // https://github.com/ggml-org/llama.cpp/issues/17043 - if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) { + if (n_past > 0 && n_past < slot.prompt.n_tokens()) { const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); if (pos_min == -1) { SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); @@ -2341,9 +2341,6 @@ private: } if (pos_min > pos_min_thold) { - // TODO: support can be added in the future when corresponding vision models get released - GGML_ASSERT(!slot.prompt.tokens.has_mtmd); - SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); // search for a context checkpoint @@ -2364,18 +2361,20 @@ private: const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); if (n != checkpoint_size) { - SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); + SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024); do_reset = true; //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); } else { - n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max)); - SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); + pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max)); + n_past = slot.prompt.tokens.size_up_to_pos(pos_next); + SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024); } } if (do_reset) { SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n", "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + pos_next = 0; n_past = 0; } } @@ -2386,7 +2385,7 @@ private: for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { const auto & cur = *it; if (cur.pos_min > pos_min_thold) { - SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024); + SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, (float) cur.data.size() / 1024 / 1024); it = slot.prompt.checkpoints.erase(it); } else { ++it; @@ -2402,7 +2401,7 @@ private: SLT_WRN(slot, "n_past was set to %d\n", n_past); } - slot.n_prompt_tokens_cache = n_past; + slot.n_prompt_tokens_cache = n_past; slot.n_prompt_tokens_processed = 0; slot.prompt.tokens.keep_first(n_past); @@ -2520,10 +2519,6 @@ private: } } - // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str()); - - SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); - // entire prompt has been processed if (slot.prompt.n_tokens() == slot.task->n_tokens()) { slot.state = SLOT_STATE_DONE_PROMPT; @@ -2536,8 +2531,6 @@ private: slot.n_decoded = 0; slot.i_batch = batch.n_tokens - 1; - SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens); - slot.init_sampler(); const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); @@ -2549,13 +2542,15 @@ private: // no need to create checkpoints that are too close together do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); + // note: we create the checkpoint before calling llama_decode(), so the current batch is not + // yet processed and therefore it is not part of the checkpoint. if (do_checkpoint) { while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { // make room for the new checkpoint, if needed const auto & cur = slot.prompt.checkpoints.front(); - SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", + cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); } @@ -2563,16 +2558,21 @@ private: const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ - /*.pos_min = */ pos_min, - /*.pos_max = */ pos_max, - /*.data = */ std::vector(checkpoint_size), + /*.pos_min = */ pos_min, + /*.pos_max = */ pos_max, + /*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens, + /*.data = */ std::vector(checkpoint_size), }); llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", + (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); } + + SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens); + } else { + SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); } } diff --git a/tools/server/server-task.h b/tools/server/server-task.h index a69e8f1a3d..e2e3e5a582 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -557,6 +557,8 @@ struct server_prompt_checkpoint { llama_pos pos_min; llama_pos pos_max; + int64_t n_tokens; + std::vector data; size_t size() const {