server : support multi-modal context checkpoints (#19849)
* Modify llama-memory-hybrid-iswa.cpp * Modify llama-memory-recurrent.cpp * Modify server-common.cpp * Modify server-common.h * Modify server-context.cpp * Modify server-task.h * Added comment to llama-memory-hybrid-iswa.cpp * Remove comment from server-context.cpp * Stylistic fix server-context.cpp * Fix an issue when seqrm isn't called in server-context.cpp * cont : alternative impl * cont : cleanup * cont : n_tokens -> int64_t --------- Co-authored-by: timkhronos <timkhronos@gmail.com>
This commit is contained in:
parent
c747294b2d
commit
d7d826b3c1
|
|
@ -163,7 +163,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
|
|||
const auto & cell = cells[tail_id];
|
||||
// partial intersection is invalid if it includes the final pos
|
||||
if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) {
|
||||
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
|
||||
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1);
|
||||
return false;
|
||||
}
|
||||
// invalidate tails which will be cleared
|
||||
|
|
|
|||
|
|
@ -231,19 +231,77 @@ server_tokens::server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) :
|
|||
server_tokens::server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {
|
||||
}
|
||||
|
||||
llama_pos server_tokens::pos_next() const {
|
||||
llama_pos server_tokens::pos_next(int64_t n_tokens) const {
|
||||
if (!has_mtmd) {
|
||||
return tokens.size();
|
||||
if (n_tokens < 0) {
|
||||
return tokens.size();
|
||||
}
|
||||
|
||||
return n_tokens;
|
||||
}
|
||||
|
||||
llama_pos res = tokens.size();
|
||||
if (n_tokens < 0) {
|
||||
llama_pos res = tokens.size();
|
||||
|
||||
for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
|
||||
const auto & chunk = it->second;
|
||||
res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
|
||||
for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) {
|
||||
const auto & chunk = it->second;
|
||||
res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get());
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
return res;
|
||||
int64_t idx = 0;
|
||||
llama_pos pos = 0;
|
||||
|
||||
GGML_ASSERT(n_tokens <= (int64_t)tokens.size());
|
||||
|
||||
while (idx < n_tokens) {
|
||||
const auto media_it = map_idx_to_media.find(idx);
|
||||
if (media_it != map_idx_to_media.end()) {
|
||||
const auto & chunk = media_it->second;
|
||||
const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
|
||||
const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get());
|
||||
|
||||
pos += n_pos;
|
||||
idx += n_tok;
|
||||
} else {
|
||||
pos++;
|
||||
idx++;
|
||||
}
|
||||
}
|
||||
|
||||
return pos;
|
||||
}
|
||||
|
||||
size_t server_tokens::size_up_to_pos(llama_pos max_pos) const {
|
||||
if (!has_mtmd) {
|
||||
return std::min((size_t)(max_pos + 1), tokens.size());
|
||||
}
|
||||
|
||||
size_t idx = 0;
|
||||
llama_pos pos = 0;
|
||||
|
||||
while (idx < tokens.size()) {
|
||||
const auto media_it = map_idx_to_media.find(idx);
|
||||
if (media_it != map_idx_to_media.end()) {
|
||||
const auto & chunk = media_it->second;
|
||||
const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get());
|
||||
const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get());
|
||||
|
||||
pos += n_pos;
|
||||
idx += n_tok;
|
||||
} else {
|
||||
pos++;
|
||||
idx++;
|
||||
}
|
||||
|
||||
if (pos > max_pos) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return idx;
|
||||
}
|
||||
|
||||
std::string server_tokens::str() const {
|
||||
|
|
|
|||
|
|
@ -167,7 +167,12 @@ public:
|
|||
// for debugging
|
||||
std::string str() const;
|
||||
|
||||
llama_pos pos_next() const;
|
||||
// the next position after n_tokens. if n_tokens < 0, return the next position after all tokens.
|
||||
llama_pos pos_next(int64_t n_tokens = -1) const;
|
||||
|
||||
// number of tokens with position <= max_pos
|
||||
size_t size_up_to_pos(llama_pos max_pos) const;
|
||||
|
||||
const mtmd::input_chunk_ptr & find_chunk(size_t idx) const;
|
||||
|
||||
void push_back(llama_token tok);
|
||||
|
|
|
|||
|
|
@ -1442,7 +1442,7 @@ private:
|
|||
res->id = slot.task->id;
|
||||
res->id_slot = slot.id;
|
||||
|
||||
res->index = slot.task->index;
|
||||
res->index = slot.task->index;
|
||||
|
||||
// keep copy of last generated text for debugging purposes
|
||||
if (slots_debug) {
|
||||
|
|
@ -2282,15 +2282,15 @@ private:
|
|||
n_past = 0;
|
||||
}
|
||||
|
||||
llama_pos pos_next = slot.prompt.tokens.pos_next(n_past);
|
||||
|
||||
// note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
|
||||
const auto n_swa = std::max(1, llama_model_n_swa(model));
|
||||
|
||||
// the largest pos_min required for a checkpoint to be useful
|
||||
const auto pos_min_thold = std::max(0, n_past - n_swa);
|
||||
const auto pos_min_thold = std::max(0, pos_next - n_swa);
|
||||
|
||||
// note: disallow with mtmd contexts for now
|
||||
// https://github.com/ggml-org/llama.cpp/issues/17043
|
||||
if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) {
|
||||
if (n_past > 0 && n_past < slot.prompt.n_tokens()) {
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
|
||||
if (pos_min == -1) {
|
||||
SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min);
|
||||
|
|
@ -2341,9 +2341,6 @@ private:
|
|||
}
|
||||
|
||||
if (pos_min > pos_min_thold) {
|
||||
// TODO: support can be added in the future when corresponding vision models get released
|
||||
GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
|
||||
|
||||
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
|
||||
|
||||
// search for a context checkpoint
|
||||
|
|
@ -2364,18 +2361,20 @@ private:
|
|||
const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
if (n != checkpoint_size) {
|
||||
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
|
||||
SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
|
||||
do_reset = true;
|
||||
//printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint");
|
||||
} else {
|
||||
n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max));
|
||||
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024);
|
||||
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
|
||||
n_past = slot.prompt.tokens.size_up_to_pos(pos_next);
|
||||
SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024);
|
||||
}
|
||||
}
|
||||
|
||||
if (do_reset) {
|
||||
SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n",
|
||||
"https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
|
||||
pos_next = 0;
|
||||
n_past = 0;
|
||||
}
|
||||
}
|
||||
|
|
@ -2386,7 +2385,7 @@ private:
|
|||
for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) {
|
||||
const auto & cur = *it;
|
||||
if (cur.pos_min > pos_min_thold) {
|
||||
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024);
|
||||
SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, (float) cur.data.size() / 1024 / 1024);
|
||||
it = slot.prompt.checkpoints.erase(it);
|
||||
} else {
|
||||
++it;
|
||||
|
|
@ -2402,7 +2401,7 @@ private:
|
|||
SLT_WRN(slot, "n_past was set to %d\n", n_past);
|
||||
}
|
||||
|
||||
slot.n_prompt_tokens_cache = n_past;
|
||||
slot.n_prompt_tokens_cache = n_past;
|
||||
slot.n_prompt_tokens_processed = 0;
|
||||
|
||||
slot.prompt.tokens.keep_first(n_past);
|
||||
|
|
@ -2520,10 +2519,6 @@ private:
|
|||
}
|
||||
}
|
||||
|
||||
// SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str());
|
||||
|
||||
SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
|
||||
|
||||
// entire prompt has been processed
|
||||
if (slot.prompt.n_tokens() == slot.task->n_tokens()) {
|
||||
slot.state = SLOT_STATE_DONE_PROMPT;
|
||||
|
|
@ -2536,8 +2531,6 @@ private:
|
|||
slot.n_decoded = 0;
|
||||
slot.i_batch = batch.n_tokens - 1;
|
||||
|
||||
SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
|
||||
|
||||
slot.init_sampler();
|
||||
|
||||
const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
|
||||
|
|
@ -2549,13 +2542,15 @@ private:
|
|||
// no need to create checkpoints that are too close together
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
|
||||
|
||||
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
|
||||
// yet processed and therefore it is not part of the checkpoint.
|
||||
if (do_checkpoint) {
|
||||
while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) {
|
||||
// make room for the new checkpoint, if needed
|
||||
const auto & cur = slot.prompt.checkpoints.front();
|
||||
|
||||
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||
cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
||||
SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||
cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
|
||||
|
||||
slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin());
|
||||
}
|
||||
|
|
@ -2563,16 +2558,21 @@ private:
|
|||
const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{
|
||||
/*.pos_min = */ pos_min,
|
||||
/*.pos_max = */ pos_max,
|
||||
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
||||
/*.pos_min = */ pos_min,
|
||||
/*.pos_max = */ pos_max,
|
||||
/*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens,
|
||||
/*.data = */ std::vector<uint8_t>(checkpoint_size),
|
||||
});
|
||||
|
||||
llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
|
||||
|
||||
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n",
|
||||
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024);
|
||||
SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
|
||||
(int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024);
|
||||
}
|
||||
|
||||
SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
|
||||
} else {
|
||||
SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -557,6 +557,8 @@ struct server_prompt_checkpoint {
|
|||
llama_pos pos_min;
|
||||
llama_pos pos_max;
|
||||
|
||||
int64_t n_tokens;
|
||||
|
||||
std::vector<uint8_t> data;
|
||||
|
||||
size_t size() const {
|
||||
|
|
|
|||
Loading…
Reference in New Issue