diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 324c3af30c..af6e053424 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -79,6 +79,8 @@ struct server_slot {
 
     common_speculative * spec = nullptr;
 
+    // TODO: move members that belong to the task (such as `generated_text`, `has_new_line`) to task_results_state
+    // see https://github.com/ggml-org/llama.cpp/pull/18283#issuecomment-3710175837
     std::unique_ptr<server_task> task;
     std::unique_ptr<server_task> task_prev; // used for debugging
 
@@ -153,7 +155,7 @@ struct server_slot {
 
     common_sampler_ptr smpl;
 
-    llama_token sampled; // in speculative mode, this is the last accepted token
+    llama_token  sampled; // in speculative mode, this is the last accepted token
     llama_tokens drafted;
 
     // stats
@@ -201,12 +203,46 @@ struct server_slot {
         alora_invocation_start = -1;
     }
 
+    // remove cached prompt + tokens
+    void clear(bool allow_processing) {
+        if (!allow_processing) {
+            GGML_ASSERT(!is_processing());
+        }
+
+        SLT_INF(*this, "clearing slot with %zu tokens\n", prompt.tokens.size());
+
+        llama_memory_seq_rm(llama_get_memory(ctx), id, -1, -1);
+        prompt.tokens.clear();
+    }
+
+    void init_sampler() const {
+        const int64_t t_start = ggml_time_us();
+
+        common_sampler_reset(smpl.get());
+
+        int n_text = 0;
+
+        for (int i = 0; i < (int) prompt.tokens.size(); i++) {
+            const llama_token id = prompt.tokens[i];
+
+            if (id != LLAMA_TOKEN_NULL) {
+                common_sampler_accept(smpl.get(), id, false);
+                n_text++;
+            }
+        }
+
+        SLT_INF(*this, "init sampler, took %0.2f ms, tokens: text = %d, total = %d\n",
+                (ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size());
+    }
+
+    // TODO: move to server_task
     bool need_embd() const {
         GGML_ASSERT(task);
 
         return server_task_type_need_embd(task->type);
     }
 
+    // TODO: move to server_task
     bool need_logits() const {
         GGML_ASSERT(task);
 
@@ -258,10 +294,13 @@ struct server_slot {
             SLT_WRN(*this, "%s", "slot is not processing\n");
             return;
         }
+
         generated_token_probs.push_back(token);
     }
 
     int get_n_draft_max() const {
+        GGML_ASSERT(task);
+
         if (!can_speculate()) {
             return 0;
         }
@@ -287,12 +326,14 @@ struct server_slot {
     }
 
     // note: a slot can also be either a parent or a child
+    // TODO: move to server_task
     bool is_parent() const {
-        return is_processing() && task->n_children > 0;
+        return task->n_children > 0;
     }
 
+    // TODO: move to server_task
     bool is_child() const {
-        return is_processing() && task->id_parent >= 0;
+        return task->id_parent >= 0;
     }
 
     void release() {
@@ -301,10 +342,16 @@ struct server_slot {
 
         SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated);
 
-        t_last_used = ggml_time_us();
+        t_last_used        = ggml_time_us();
         t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
 
+        state = SLOT_STATE_IDLE;
+        // do not keep context of the child slots - the parent's context is enough
+        if (is_child()) {
+            clear(false);
+        }
+
         task_prev = std::move(task);
         task.reset();
 
@@ -425,14 +472,22 @@ struct server_slot {
     }
 
     void copy_state_to(server_slot & other) const {
-        llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1);
-        llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1);
+        GGML_ASSERT(state == SLOT_STATE_DONE_PROMPT);
+
+        llama_memory_seq_rm(llama_get_memory(ctx), other.id, -1, -1);
+        llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, -1, -1);
+
         other.n_decoded   = n_decoded;
         other.n_remaining = n_remaining;
         other.i_batch     = i_batch;
+
+        other.t_start_process_prompt = t_start_process_prompt;
+        other.t_prompt_processing    = t_prompt_processing;
 
         other.n_prompt_tokens_cache     = n_prompt_tokens_cache;
         other.n_prompt_tokens_processed = n_prompt_tokens_processed;
+
+        other.prompt = prompt.clone();
+        other.init_sampler();
     }
 };
 
@@ -745,6 +800,7 @@ private:
         }
 
         slots.clear();
+
         for (int i = 0; i < params_base.n_parallel; i++) {
             server_slot slot;
 
@@ -993,7 +1049,7 @@ private:
            ret->prompt_save(*prompt_cache);
 
            if (!ret->prompt_load(*prompt_cache, task.tokens)) {
-                clear_slot(*ret);
+                ret->clear(false);
            }
 
            prompt_cache->update();
@@ -1005,17 +1061,6 @@ private:
         return ret;
     }
 
-    void clear_slot(server_slot & slot, bool allow_processing = false) const {
-        if (!allow_processing) {
-            GGML_ASSERT(!slot.is_processing());
-        }
-
-        SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size());
-
-        llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1);
-        slot.prompt.tokens.clear();
-    }
-
     // return true if at least one slot has been cleared
     // TODO: improve logic
     //       - smarter decision which slot to clear (LRU or longest prompt?)
@@ -1036,7 +1081,7 @@ private:
             if (slot.prompt.n_tokens() > 0) {
                 SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size());
 
-                clear_slot(slot);
+                slot.clear(false);
 
                 res = true;
 
@@ -1182,7 +1227,7 @@ private:
             ? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt
             : SLOT_STATE_STARTED;
 
-        SLT_INF(slot, "%s", "processing task\n");
+        SLT_INF(slot, "processing task, is_child = %d\n", slot.is_child());
 
         return true;
     }
@@ -1819,7 +1864,7 @@ private:
                 // Erase token cache
                 const size_t n_erased = slot->prompt.tokens.size();
 
-                clear_slot(*slot);
+                slot->clear(false);
 
                 auto res = std::make_unique();
                 res->id = task.id;
@@ -2053,8 +2098,29 @@ private:
                 continue;
             }
 
+            // check if this is a child slot
+            if (slot.state == SLOT_STATE_WAIT_OTHER) {
+                SLT_DBG(slot, "%s", "waiting for parent slot to complete\n");
+                continue;
+            }
+
             // this slot still has a prompt to be processed
             if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
+                // wait for all children to be launched
+                if (slot.is_parent()) {
+                    int n_launched = 0;
+                    for (auto & other : slots) {
+                        if (other.is_processing() && other.is_child() && other.task->id_parent == slot.task->id) {
+                            ++n_launched;
+                        }
+                    }
+
+                    if (n_launched < slot.task->n_children) {
+                        SLT_DBG(slot, "waiting for children to be launched, n_children = %d, n_launched = %d\n", slot.task->n_children, n_launched);
+                        continue;
+                    }
+                }
+
                 const auto & input_tokens = slot.task->tokens;
 
                 // TODO: maybe move branch to outside of this loop in the future
@@ -2355,7 +2421,7 @@ private:
                     if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) {
                         SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0);
 
-                        clear_slot(slot, /*allow_processing=*/true);
+                        slot.clear(true);
 
                         // there is no common part left
                         slot.n_prompt_tokens_cache = 0;
@@ -2455,16 +2521,6 @@ private:
 
                     GGML_ASSERT(batch.n_tokens > 0);
 
-                    common_sampler_reset(slot.smpl.get());
-
-                    // Process all prompt tokens through sampler system
-                    for (int i = 0; i < slot.task->n_tokens(); ++i) {
-                        llama_token id = input_tokens[i];
-                        if (id != LLAMA_TOKEN_NULL) {
-                            common_sampler_accept(slot.smpl.get(), id, false);
-                        }
-                    }
-
                     // extract the logits only for the last token
                     batch.logits[batch.n_tokens - 1] = true;
 
@@ -2473,6 +2529,8 @@ private:
 
                     SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens);
 
+                    slot.init_sampler();
+
                     const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id);
                     const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id);
@@ -2519,11 +2577,6 @@ private:
             }
         }
 
-        if (batch.n_tokens == 0) {
-            SRV_WRN("%s", "no tokens to decode\n");
-            return;
-        }
-
         SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
 
         if (slot_batched) {
@@ -2540,6 +2593,10 @@ private:
             llama_set_embeddings(ctx, slot_batched->need_embd());
         }
 
+        if (batch.n_tokens == 0) {
+            SRV_WRN("%s", "no tokens to decode\n");
+        }
+
         int32_t i_next = 0;
 
         // process the created batch of tokens
@@ -2591,7 +2648,7 @@ private:
 
                     // note: it's complicated to keep track of how much of the current batch has been
                     //       processed before the error occurred, so we simply clear the entire context
-                    clear_slot(slot);
+                    slot.clear(false);
                 }
             }
 
@@ -2615,27 +2672,34 @@ private:
         // on successful decode, restore the original batch size
         n_batch = llama_n_batch(ctx);
 
+        // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too
         for (auto & slot : slots) {
-            // may need to copy state to other slots
             if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) {
-                std::vector<server_slot *> child_slots;
+                SLT_INF(slot, "parent task prompt done, n_children = %d\n", slot.task->n_children);
+
+                std::vector<server_slot *> children;
                 for (auto & other : slots) {
                     if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) {
-                        child_slots.push_back(&other);
+                        children.push_back(&other);
                     }
                 }
 
                 // we can only proceed if all child slots are having the correct tasks
-                if (child_slots.size() == slot.task->n_children) {
+                if (slot.task->n_children == (int) children.size()) {
                     // copy state to the child slots
-                    for (auto & child : child_slots) {
-                        SLT_INF(slot, "copying state to child %d\n", child->id);
+                    for (auto & child : children) {
+                        SLT_INF(slot, " - copying state to child %d\n", child->id);
+
+                        GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER);
+
                         slot.copy_state_to(*child);
                         child->state = SLOT_STATE_DONE_PROMPT;
                     }
                 }
             }
+        }
 
+        for (auto & slot : slots) {
             // optionally send prompt processing progress
             if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) {
                 if (slot.task->params.stream && slot.task->params.return_progress) {
@@ -2720,7 +2784,7 @@ private:
                 continue;
             }
 
-            size_t n_draft = slot.drafted.size();
+            const size_t n_draft = slot.drafted.size();
 
             // the accepted tokens from the speculation
             const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
@@ -2923,9 +2987,11 @@ std::unique_ptr server_routes::handle_completions_impl(
         task.params.oaicompat_cmpl_id = completion_id;
         task.params.oaicompat_model   = meta->model_name;
 
+        // prepare child tasks
         if (task.params.n_cmpl > 1) {
            task.n_children = task.params.n_cmpl - 1;
-            for (size_t j = 0; j < task.n_children; j++) {
+
+            for (int j = 0; j < task.n_children; j++) {
                server_task child = task.create_child(task.id, rd.get_new_id());
 
                // use different sampling seed for each child
@@ -2938,7 +3004,8 @@ std::unique_ptr server_routes::handle_completions_impl(
            }
        }
 
-        tasks.push_back(std::move(task));
+        // note: the parent task always launches first
+        tasks.insert(tasks.begin(), std::move(task));
     }
 
     rd.post_tasks(std::move(tasks));
diff --git a/tools/server/server-task.h b/tools/server/server-task.h
index ead1491182..cf08fced63 100644
--- a/tools/server/server-task.h
+++ b/tools/server/server-task.h
@@ -121,8 +121,8 @@ struct server_task {
     int id_slot = -1;
 
     // used by parallel sampling (multiple completions from same prompt)
-    size_t n_children = 0; // number of tasks reusing this prompt
-    int id_parent = -1;
+    int n_children = 0; // number of tasks reusing this prompt
+    int id_parent  = -1;
 
     // used by SERVER_TASK_TYPE_INFERENCE
     task_params params;
@@ -173,11 +173,13 @@ struct server_task {
 
     server_task create_child(int id_parent, int id_child) const {
         server_task copy;
+
         copy.id        = id_child;
         copy.id_parent = id_parent;
         copy.params    = params;
         copy.type      = type;
         copy.tokens    = tokens.clone();
+
         return copy;
     }
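For reviewers, the intended flow of the `n_cmpl > 1` path in one place: `handle_completions_impl` turns a single request into one parent task plus `n_cmpl - 1` child tasks, each child getting its own id and sampling seed, and queues the parent first. `update_slots` then parks each child slot in `SLOT_STATE_WAIT_OTHER`, holds the parent at the prompt stage until all children are launched, and once the parent reaches `SLOT_STATE_DONE_PROMPT` it calls `copy_state_to` to clone the KV sequence, timings, prompt, and sampler state into every child. The snippet below is only a minimal standalone sketch of the fan-out and queue ordering; `toy_task`, `next_id`, and the `base_seed + 1 + j` seed scheme are invented for illustration and are not the real `server_task` API.

```cpp
// toy sketch of the parent/child task fan-out (hypothetical types, not llama.cpp code)
#include <cstdio>
#include <vector>

struct toy_task {
    int id         = -1;
    int id_parent  = -1; // -1 means "not a child"
    int n_children = 0;
    int seed       = 0;
};

int main() {
    int next_id = 0;

    const int n_cmpl    = 3;  // e.g. a request asking for 3 completions of the same prompt
    const int base_seed = 42;

    std::vector<toy_task> tasks;

    toy_task parent;
    parent.id   = next_id++;
    parent.seed = base_seed;

    if (n_cmpl > 1) {
        parent.n_children = n_cmpl - 1;

        for (int j = 0; j < parent.n_children; j++) {
            toy_task child;
            child.id        = next_id++;
            child.id_parent = parent.id;
            child.seed      = base_seed + 1 + j; // each child samples with a different seed

            tasks.push_back(child);
        }
    }

    // the parent goes to the front of the queue so its prompt is decoded first;
    // the children wait (SLOT_STATE_WAIT_OTHER in the real server) until its state is copied to them
    tasks.insert(tasks.begin(), parent);

    for (const auto & t : tasks) {
        printf("task %d: id_parent = %d, n_children = %d, seed = %d\n",
               t.id, t.id_parent, t.n_children, t.seed);
    }

    return 0;
}
```

Queuing the parent ahead of its children is what lets the prompt be decoded once and then fanned out via `copy_state_to`, instead of being recomputed `n_cmpl` times.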