server : task declares needs (embd, logits, sampling)
This commit is contained in:
parent
b579b970b4
commit
ffa0d15e86
|
|
@ -45,26 +45,6 @@ enum server_state {
|
|||
SERVER_STATE_READY, // Server is ready and model is loaded
|
||||
};
|
||||
|
||||
// Reports whether a task of the given type has to produce embeddings.
// Only embedding and rerank tasks do; every other task type yields false.
static bool server_task_type_need_embd(server_task_type task_type) {
    return task_type == SERVER_TASK_TYPE_EMBEDDING ||
           task_type == SERVER_TASK_TYPE_RERANK;
}
|
||||
|
||||
// Reports whether a task of the given type has to produce logits.
// Only completion and infill tasks do; every other task type yields false.
static bool server_task_type_need_logits(server_task_type task_type) {
    return task_type == SERVER_TASK_TYPE_COMPLETION ||
           task_type == SERVER_TASK_TYPE_INFILL;
}
|
||||
|
||||
struct server_slot {
|
||||
int id;
|
||||
|
||||
|
|
@ -235,25 +215,13 @@ struct server_slot {
|
|||
(ggml_time_us() - t_start) / 1000.0, n_text, (int) prompt.tokens.size());
|
||||
}
|
||||
|
||||
// TODO: move to server_task
|
||||
bool need_embd() const {
|
||||
GGML_ASSERT(task);
|
||||
|
||||
return server_task_type_need_embd(task->type);
|
||||
}
|
||||
|
||||
// TODO: move to server_task
|
||||
bool need_logits() const {
|
||||
GGML_ASSERT(task);
|
||||
|
||||
return server_task_type_need_logits(task->type);
|
||||
}
|
||||
|
||||
// if the context does not have a memory module then all embeddings have to be computed within a single ubatch
|
||||
// also we cannot split if the pooling would require any past tokens
|
||||
bool can_split() const {
|
||||
GGML_ASSERT(task);
|
||||
|
||||
return
|
||||
!need_embd() ||
|
||||
!task->need_embd() ||
|
||||
(llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
|
||||
}
|
||||
|
||||
|
|
@ -1182,7 +1150,7 @@ private:
|
|||
SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
|
||||
|
||||
// initialize samplers
|
||||
if (task.uses_sampling()) {
|
||||
if (task.need_sampling()) {
|
||||
slot.smpl.reset(common_sampler_init(model, task.params.sampling));
|
||||
|
||||
if (slot.smpl == nullptr) {
|
||||
|
|
@ -2163,7 +2131,7 @@ private:
|
|||
}
|
||||
|
||||
// TODO: support memory-less logits computation
|
||||
if (slot.need_logits() && !llama_get_memory(ctx)) {
|
||||
if (slot.task->need_logits() && !llama_get_memory(ctx)) {
|
||||
send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER);
|
||||
slot.release();
|
||||
continue;
|
||||
|
|
@ -2502,7 +2470,7 @@ private:
|
|||
cur_tok,
|
||||
slot.prompt.tokens.pos_next(),
|
||||
{ slot.id },
|
||||
slot.need_embd());
|
||||
slot.task->need_embd());
|
||||
slot.prompt.tokens.push_back(cur_tok);
|
||||
|
||||
slot.n_prompt_tokens_processed++;
|
||||
|
|
@ -2592,7 +2560,7 @@ private:
|
|||
slot_batched->lora[alora_disabled_id].scale = alora_scale;
|
||||
}
|
||||
|
||||
llama_set_embeddings(ctx, slot_batched->need_embd());
|
||||
llama_set_embeddings(ctx, slot_batched->task->need_embd());
|
||||
}
|
||||
|
||||
for (auto & slot : slots) {
|
||||
|
|
@ -2735,7 +2703,7 @@ private:
|
|||
continue; // continue loop of slots
|
||||
}
|
||||
|
||||
GGML_ASSERT(slot.task->uses_sampling());
|
||||
GGML_ASSERT(slot.task->need_sampling());
|
||||
|
||||
// prompt evaluated for next-token prediction
|
||||
slot.state = SLOT_STATE_GENERATING;
|
||||
|
|
|
|||
|
|
@ -156,9 +156,34 @@ struct server_task {
|
|||
return tokens.size();
|
||||
}
|
||||
|
||||
bool uses_sampling() const {
|
||||
return type != SERVER_TASK_TYPE_EMBEDDING &&
|
||||
type != SERVER_TASK_TYPE_RERANK;
|
||||
bool need_embd() const {
|
||||
switch (type) {
|
||||
case SERVER_TASK_TYPE_EMBEDDING:
|
||||
case SERVER_TASK_TYPE_RERANK:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool need_logits() const {
|
||||
switch (type) {
|
||||
case SERVER_TASK_TYPE_COMPLETION:
|
||||
case SERVER_TASK_TYPE_INFILL:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool need_sampling() const {
|
||||
switch (type) {
|
||||
case SERVER_TASK_TYPE_COMPLETION:
|
||||
case SERVER_TASK_TYPE_INFILL:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
static task_params params_from_json_cmpl(
|
||||
|
|
|
|||
Loading…
Reference in New Issue