server : improve mtmd ctx checkpoints
This commit is contained in:
parent
21c8045214
commit
6051df2f2b
|
|
@ -2459,8 +2459,39 @@ private:
|
||||||
slot.n_prompt_tokens_cache = 0;
|
slot.n_prompt_tokens_cache = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If using an alora, there may be uncached tokens that come
|
||||||
|
// before the invocation sequence. When this happens, the
|
||||||
|
// tokens before the invocation sequence need to be
|
||||||
|
// processed without the adapter in a separate batch, then
|
||||||
|
// the adapter needs to be enabled for the remaining tokens.
|
||||||
|
if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) {
|
||||||
|
SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
|
||||||
|
const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
|
||||||
|
GGML_ASSERT(enabled_loras.size() == 1);
|
||||||
|
alora_scale = slot.lora[enabled_loras[0]].scale;
|
||||||
|
slot.lora[enabled_loras[0]].scale = 0.0f;
|
||||||
|
alora_disabled_id = enabled_loras[0];
|
||||||
|
}
|
||||||
|
|
||||||
bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
|
bool do_checkpoint = params_base.n_ctx_checkpoints > 0;
|
||||||
|
|
||||||
|
// make checkpoints only for completion tasks
|
||||||
|
do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
|
||||||
|
|
||||||
|
// make a checkpoint of the parts of the memory that cannot be rolled back.
|
||||||
|
// checkpoints are created only if:
|
||||||
|
// - the model uses SWA and we are not using `swa_full`
|
||||||
|
// - the model architecture is marked as recurrent or hybrid
|
||||||
|
//
|
||||||
|
// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
||||||
|
do_checkpoint = do_checkpoint && (
|
||||||
|
llama_model_is_recurrent(model) ||
|
||||||
|
llama_model_is_hybrid(model) ||
|
||||||
|
(llama_model_n_swa(model) > 0 && !params_base.swa_full)
|
||||||
|
);
|
||||||
|
|
||||||
|
bool has_mtmd = false;
|
||||||
|
|
||||||
// check if we should process the image
|
// check if we should process the image
|
||||||
if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
|
if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) {
|
||||||
// process the image
|
// process the image
|
||||||
|
|
@ -2481,38 +2512,9 @@ private:
|
||||||
slot.prompt.tokens.push_back(chunk.get()); // copy
|
slot.prompt.tokens.push_back(chunk.get()); // copy
|
||||||
}
|
}
|
||||||
|
|
||||||
do_checkpoint = false; // do not checkpoint right after an image chunk
|
has_mtmd = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If using an alora, there may be uncached tokens that come
|
|
||||||
// before the invocation sequence. When this happens, the
|
|
||||||
// tokens before the invocation sequence need to be
|
|
||||||
// processed without the adapter in a separate batch, then
|
|
||||||
// the adapter needs to be enabled for the remaining tokens.
|
|
||||||
if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) {
|
|
||||||
SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start);
|
|
||||||
const auto & enabled_loras = lora_get_enabled_ids(slot.lora);
|
|
||||||
GGML_ASSERT(enabled_loras.size() == 1);
|
|
||||||
alora_scale = slot.lora[enabled_loras[0]].scale;
|
|
||||||
slot.lora[enabled_loras[0]].scale = 0.0f;
|
|
||||||
alora_disabled_id = enabled_loras[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
// make checkpoints only for completion tasks
|
|
||||||
do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION;
|
|
||||||
|
|
||||||
// make a checkpoint of the parts of the memory that cannot be rolled back.
|
|
||||||
// checkpoints are created only if:
|
|
||||||
// - the model uses SWA and we are not using `swa_full`
|
|
||||||
// - the model architecture is marked as recurrent or hybrid
|
|
||||||
//
|
|
||||||
// TODO: try to make this conditional on the context or the memory module, instead of the model type
|
|
||||||
do_checkpoint = do_checkpoint && (
|
|
||||||
llama_model_is_recurrent(model) ||
|
|
||||||
llama_model_is_hybrid(model) ||
|
|
||||||
(llama_model_n_swa(model) > 0 && !params_base.swa_full)
|
|
||||||
);
|
|
||||||
|
|
||||||
// add prompt tokens for processing in the current batch
|
// add prompt tokens for processing in the current batch
|
||||||
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
|
while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) {
|
||||||
// get next token to process
|
// get next token to process
|
||||||
|
|
@ -2544,13 +2546,13 @@ private:
|
||||||
// - 4 + n_ubatch
|
// - 4 + n_ubatch
|
||||||
// - 4
|
// - 4
|
||||||
// ref: https://github.com/ggml-org/llama.cpp/pull/20288
|
// ref: https://github.com/ggml-org/llama.cpp/pull/20288
|
||||||
{
|
if (do_checkpoint) {
|
||||||
static const int checkpoint_offsets[] = {4 + n_ubatch, 4};
|
static const int checkpoint_offsets[] = {4 + n_ubatch, 4};
|
||||||
|
|
||||||
bool should_break = false;
|
bool should_break = false;
|
||||||
for (int offset : checkpoint_offsets) {
|
for (int offset : checkpoint_offsets) {
|
||||||
const int n_last = std::min(n_batch, offset);
|
const int n_last = std::min(n_batch, offset);
|
||||||
if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
|
if (slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) {
|
||||||
should_break = true;
|
should_break = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
@ -2607,10 +2609,13 @@ private:
|
||||||
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
|
const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
|
||||||
|
|
||||||
// no need for empty or small checkpoints
|
// no need for empty or small checkpoints
|
||||||
do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64);
|
do_checkpoint = do_checkpoint && (pos_min >= 0 && slot.prompt.n_tokens() >= 64);
|
||||||
|
|
||||||
|
// do not checkpoint after mtmd chunks
|
||||||
|
do_checkpoint = do_checkpoint && !has_mtmd;
|
||||||
|
|
||||||
// no need to create checkpoints that are too close together
|
// no need to create checkpoints that are too close together
|
||||||
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64);
|
do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || slot.prompt.n_tokens() - n_tokens_cur > slot.prompt.checkpoints.back().n_tokens + 64);
|
||||||
|
|
||||||
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
|
// note: we create the checkpoint before calling llama_decode(), so the current batch is not
|
||||||
// yet processed and therefore it is not part of the checkpoint.
|
// yet processed and therefore it is not part of the checkpoint.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue