server : improve mtmd ctx checkpoints (#20726)
* server : improve mtmd ctx checkpoints * server : fix off-by-one in pos_min_thold
This commit is contained in:
parent
1af9dab32b
commit
ab9d4c3678
|
|
@ -2307,8 +2307,8 @@ private:
|
|||
|
||||
llama_pos pos_next = slot.prompt.tokens.pos_next(n_past);
|
||||
|
||||
// note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1
|
||||
const auto n_swa = std::max(1, llama_model_n_swa(model));
|
||||
// note: when n_swa == 0, the model does not use SWA
|
||||
const auto n_swa = std::max(0, llama_model_n_swa(model));
|
||||
|
||||
// the largest pos_min required for a checkpoint to be useful
|
||||
const auto pos_min_thold = std::max(0, pos_next - n_swa);
|
||||
|
|
@ -2363,7 +2363,7 @@ private:
|
|||
SLT_WRN(slot, "%s\n", st1.str().c_str());
|
||||
}
|
||||
|
||||
if (pos_min > pos_min_thold) {
|
||||
if (pos_min >= pos_min_thold) {
|
||||
SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa);
|
||||
|
||||
// search for a context checkpoint
|
||||
|
|
@ -2459,8 +2459,39 @@ private:
|
|||
slot.n_prompt_tokens_cache = 0;
|
||||
}
|
||||
|
||||
// If using an alora, there may be uncached tokens that come
|
||||
// before the invocation sequence. When this happens, the
|
||||
// tokens before the invocation sequence need to be
|
||||
// processed without the adapter in a separate batch, then
|
||||
// the adapter needs to be enabled for the remaining tokens.
|
||||
if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) {
|
||||
SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
|
||||
const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
|
||||
GGML_ASSERT(enabled_loras.size() == 1);
|
||||
alora_scale = slot.lora[enabled_loras[0]].scale;
|
||||
slot.lora[enabled_loras[0]].scale = 0.0f;
|
||||
alora_disabled_id = enabled_loras[0];
|
||||
}
|
||||
|
||||
bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
|
||||
|
||||
// make checkpoints only for completion tasks
|
||||
do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
|
||||
|
||||
// make a checkpoint of the parts of the memory that cannot be rolled back.
|
||||
// checkpoints are created only if:
|
||||
// - the model uses SWA and we are not using `swa_full`
|
||||
// - the model architecture is marked as recurrent or hybrid
|
||||
//
|
||||
// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
||||
do_checkpoint = do_checkpoint && (
|
||||
llama_model_is_recurrent(model) ||
|
||||
llama_model_is_hybrid(model) ||
|
||||
(llama_model_n_swa(model) > 0 && !params_base.swa_full)
|
||||
);
|
||||
|
||||
bool has_mtmd = false;
|
||||
|
||||
// check if we should process the image
|
||||
if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
|
||||
// process the image
|
||||
|
|
@ -2481,38 +2512,9 @@ private:
|
|||
slot.prompt.tokens.push_back(chunk.get()); // copy
|
||||
}
|
||||
|
||||
do_checkpoint = false; // do not checkpoint right after an image chunk
|
||||
has_mtmd = true;
|
||||
}
|
||||
|
||||
// If using an alora, there may be uncached tokens that come
|
||||
// before the invocation sequence. When this happens, the
|
||||
// tokens before the invocation sequence need to be
|
||||
// processed without the adapter in a separate batch, then
|
||||
// the adapter needs to be enabled for the remaining tokens.
|
||||
if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) {
|
||||
SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
|
||||
const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
|
||||
GGML_ASSERT(enabled_loras.size() == 1);
|
||||
alora_scale = slot.lora[enabled_loras[0]].scale;
|
||||
slot.lora[enabled_loras[0]].scale = 0.0f;
|
||||
alora_disabled_id = enabled_loras[0];
|
||||
}
|
||||
|
||||
// make checkpoints only for completion tasks
|
||||
do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
|
||||
|
||||
// make a checkpoint of the parts of the memory that cannot be rolled back.
|
||||
// checkpoints are created only if:
|
||||
// - the model uses SWA and we are not using `swa_full`
|
||||
// - the model architecture is marked as recurrent or hybrid
|
||||
//
|
||||
// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
||||
do_checkpoint = do_checkpoint && (
|
||||
llama_model_is_recurrent(model) ||
|
||||
llama_model_is_hybrid(model) ||
|
||||
(llama_model_n_swa(model) > 0 && !params_base.swa_full)
|
||||
);
|
||||
|
||||
// add prompt tokens for processing in the current batch
|
||||
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
|
||||
// get next token to process
|
||||
|
|
@ -2544,13 +2546,13 @@ private:
|
|||
// - 4 + n_ubatch
|
||||
// - 4
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/20288
|
||||
{
|
||||
if (do_checkpoint) {
|
||||
static const int checkpoint_offsets[] = {4 + n_ubatch, 4};
|
||||
|
||||
bool should_break = false;
|
||||
for (int offset : checkpoint_offsets) {
|
||||
const int n_last = std::min(n_batch, offset);
|
||||
if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
|
||||
if (slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
|
||||
should_break = true;
|
||||
break;
|
||||
}
|
||||
|
|
@ -2607,10 +2609,13 @@ private:
|
|||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
|
||||
|
||||
// no need for empty or small checkpoints
|
||||
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
|
||||
do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64);
|
||||
|
||||
// do not checkpoint after mtmd chunks
|
||||
do_checkpoint = do_checkpoint && !has_mtmd;
|
||||
|
||||
// no need to create checkpoints that are too close together
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
|
||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64);
|
||||
|
||||
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
|
||||
// yet processed and therefore it is not part of the checkpoint.
|
||||
|
|
|
|||
Loading…
Reference in New Issue