// llama.cpp/tools/server/server-task.h
//
// 522 lines
// 14 KiB
// C++

#pragma once
#include "common.h"
#include "llama.h"
#include <string>
#include <unordered_set>
#include <list>
// TODO: prevent including the whole server-common.h as we only use server_tokens
#include "server-common.h"
// JSON value type used by every to_json() declared in this header.
// nlohmann::ordered_json keeps object keys in insertion order.
using json = nlohmann::ordered_json;
// Kind of work a server_task asks for; stored in server_task::type.
// The trailing notes list the server_task fields that this header documents as belonging to a given type.
enum server_task_type {
SERVER_TASK_TYPE_COMPLETION,
SERVER_TASK_TYPE_EMBEDDING,
SERVER_TASK_TYPE_RERANK,
SERVER_TASK_TYPE_INFILL,
SERVER_TASK_TYPE_CANCEL, // uses server_task::id_target
SERVER_TASK_TYPE_NEXT_RESPONSE,
SERVER_TASK_TYPE_METRICS, // uses server_task::metrics_reset_bucket
SERVER_TASK_TYPE_SLOT_SAVE, // uses server_task::slot_action
SERVER_TASK_TYPE_SLOT_RESTORE, // uses server_task::slot_action
SERVER_TASK_TYPE_SLOT_ERASE, // uses server_task::slot_action
SERVER_TASK_TYPE_SET_LORA, // uses server_task::set_lora
};
// Output format selector; stored in the res_type field of task_params and of the result structs below.
// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common
enum task_response_type {
TASK_RESPONSE_TYPE_NONE, // llama.cpp native format
TASK_RESPONSE_TYPE_OAI_CHAT, // presumably the OpenAI-compatible chat format (see to_json_oaicompat_chat) - confirm
TASK_RESPONSE_TYPE_OAI_CMPL, // presumably the OpenAI-compatible completion format (see to_json_oaicompat) - confirm
TASK_RESPONSE_TYPE_OAI_EMBD, // presumably the OpenAI-compatible embeddings format - confirm
TASK_RESPONSE_TYPE_ANTHROPIC, // presumably the Anthropic-compatible format (see to_json_anthropic) - confirm
};
// Reason a generation ended; stored in server_task_result_cmpl_final::stop.
enum stop_type {
STOP_TYPE_NONE, // default value of server_task_result_cmpl_final::stop
STOP_TYPE_EOS, // NOTE(review): presumably an end-of-sequence token was produced - confirm
STOP_TYPE_WORD, // NOTE(review): presumably a stop string matched (see stopping_word) - confirm
STOP_TYPE_LIMIT, // NOTE(review): presumably a token/time limit was reached - confirm
};
// Per-task parameters; carried by server_task::params and echoed in
// server_task_result_cmpl_final::generation_params.
// server_task::params_from_json_cmpl() returns an instance built from a json object.
struct task_params {
// NOTE(review): stream / include_usage / return_* look like response-shaping switches - confirm against the request parser
bool stream = true;
bool include_usage = false;
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
bool return_tokens = false;
bool return_progress = false;
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict
int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters
int32_t n_cmpl = 1; // number of completions to generate from this prompt
int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled)
int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
// LoRA adapter descriptions for this task (element type comes from common.h)
std::vector<common_adapter_lora_info> lora;
// NOTE(review): presumably the stop strings for generation - confirm
std::vector<std::string> antiprompt;
// NOTE(review): presumably restricts which fields appear in the response - confirm
std::vector<std::string> response_fields;
bool timings_per_token = false;
bool post_sampling_probs = false;
// sampling and speculative-decoding settings (types declared outside this header)
struct common_params_sampling sampling;
struct common_params_speculative speculative;
// response formatting
bool verbose = false;
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_syntax oaicompat_chat_syntax;
// Embeddings
int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
// serialization helpers; declared here, defined in another translation unit
json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const;
json to_json(bool only_metrics = false) const;
};
// A unit of work handed to the server queue.
// Which fields are meaningful depends on `type`; the others keep their defaults.
struct server_task {
    int id    = -1; // to be filled by server_queue
    int index = -1; // used when there are multiple prompts (batch request)

    // used by SERVER_TASK_TYPE_CANCEL
    int id_target = -1;
    int id_slot   = -1;

    // used by parallel sampling (multiple completions from same prompt)
    size_t n_children = 0; // number of tasks reusing this prompt
    int    id_parent  = -1;

    // used by inference tasks
    // (there is no SERVER_TASK_TYPE_INFERENCE enumerator; presumably this means
    //  COMPLETION / EMBEDDING / RERANK / INFILL - confirm)
    task_params   params;
    server_tokens tokens;

    // has a default so that a default-constructed task never carries an indeterminate type
    server_task_type type = SERVER_TASK_TYPE_COMPLETION;

    // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
    struct slot_action {
        int slot_id = -1; // -1 = not set, same convention as id_slot above
        std::string filename;
        std::string filepath;
    };
    slot_action slot_action;

    // used by SERVER_TASK_TYPE_METRICS
    bool metrics_reset_bucket = false;

    // used by SERVER_TASK_TYPE_SET_LORA
    std::vector<common_adapter_lora_info> set_lora;

    server_task() = default;

    // intentionally left implicit: callers may rely on converting a server_task_type to a task
    server_task(server_task_type type) : type(type) {}

    // number of tokens held in `tokens`, narrowed to int32_t
    int32_t n_tokens() const {
        return static_cast<int32_t>(tokens.size());
    }

    // builds task_params from a json object; defined in another translation unit
    static task_params params_from_json_cmpl(
            const llama_context * ctx,
            const common_params & params_base,
            const json & data);

    // utility function: collect the ids of all given tasks (duplicates collapse)
    static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
        std::unordered_set<int> ids;
        ids.reserve(tasks.size());
        for (const auto & task : tasks) {
            ids.insert(task.id);
        }
        return ids;
    }

    // Create a child task that shares this task's type, params and a clone of its tokens.
    // Only id, index, id_parent, params, type and tokens are set; every other
    // field of the child keeps its default value.
    server_task create_child(int id_parent, int id_child, int idx) const {
        server_task copy;
        copy.id        = id_child;
        copy.index     = idx;
        copy.id_parent = id_parent;
        copy.params    = params;
        copy.type      = type;
        copy.tokens    = tokens.clone();
        return copy;
    }
};
// Timing figures attached to completion results.
// Every field has a default so that a freshly constructed object can be
// serialized without reading indeterminate values.
struct result_timings {
    int32_t cache_n = -1;

    int32_t prompt_n = -1;
    double  prompt_ms           = 0.0;
    double  prompt_per_token_ms = 0.0;
    double  prompt_per_second   = 0.0;

    int32_t predicted_n = -1;
    double  predicted_ms           = 0.0;
    double  predicted_per_token_ms = 0.0;
    double  predicted_per_second   = 0.0;

    // Optional speculative metrics - only included when > 0
    int32_t draft_n          = 0;
    int32_t draft_n_accepted = 0;

    // defined in another translation unit
    json to_json() const;
};
// Prompt-processing progress counters; carried by server_task_result_cmpl_partial::progress.
// NOTE(review): total/cache/processed are presumably token counts - confirm against the producer.
struct result_prompt_progress {
int32_t total = 0;
int32_t cache = 0;
int32_t processed = 0;
int64_t time_ms = 0;
// defined in another translation unit
json to_json() const;
};
// struct for tracking the state of a task (e.g., for streaming)
// Passed by reference to server_task_result::update() so results can fold new text into it.
struct task_result_state {
// tracking diffs for partial tool calls
std::vector<common_chat_msg_diff> diffs;
common_chat_syntax oaicompat_chat_syntax;
common_chat_msg chat_msg;
std::string generated_text; // append new chunks of generated text here
std::vector<std::string> generated_tool_call_ids;
// stores a copy of the given chat syntax; all other members start empty/default
// (single-argument and not explicit, so it also acts as a converting constructor)
task_result_state(const common_chat_syntax & oaicompat_chat_syntax)
: oaicompat_chat_syntax(oaicompat_chat_syntax) {}
// parse partial tool calls and update the internal state
// note: the `diffs` parameter is a caller-supplied output vector and shadows the member of the same name
common_chat_msg update_chat_msg(
const std::string & text_added,
bool is_partial,
std::vector<common_chat_msg_diff> & diffs);
};
// Polymorphic base of every result type in this header.
// Derived types override only the hooks that apply to them and must implement to_json().
struct server_task_result {
    int id      = -1;
    int id_slot = -1;

    // overridden to return true by server_task_result_error only
    virtual bool is_error() { return false; }

    // refined by server_task_result_cmpl_*; every other result counts as a stop
    virtual bool is_stop() { return true; }

    // index reported by the concrete result type; -1 when it does not provide one
    virtual int get_index() { return -1; }

    // hook used by server_task_result_cmpl_*; does nothing by default
    virtual void update(task_result_state &) {}

    virtual json to_json() = 0;

    virtual ~server_task_result() = default;
};
// owning std::unique_ptr to the base type, used for polymorphism of server_task_result
using server_task_result_ptr = std::unique_ptr<server_task_result>;
// One generated token together with its probability information.
// NOTE(review): tok and prob (here and in prob_info) have no default initializer,
// so they are indeterminate until the producer assigns them.
struct completion_token_output {
llama_token tok;
float prob;
std::string text_to_send;
// one candidate entry: token id, its text and its probability
struct prob_info {
llama_token tok;
std::string txt;
float prob;
};
std::vector<prob_info> probs;
// the helpers below are only declared here; their definitions live in another translation unit
json to_json(bool post_sampling_probs) const;
static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs);
static float logarithm(float x);
static std::vector<unsigned char> str_to_bytes(const std::string & str);
};
struct server_task_result_cmpl_final : server_task_result {
int index = 0;
std::string content;
llama_tokens tokens;
bool stream;
bool include_usage;
result_timings timings;
std::string prompt;
bool truncated;
int32_t n_decoded;
int32_t n_prompt_tokens;
int32_t n_tokens_cached;
bool has_new_line;
std::string stopping_word;
stop_type stop = STOP_TYPE_NONE;
bool post_sampling_probs;
std::vector<completion_token_output> probs_output;
std::vector<std::string> response_fields;
task_params generation_params;
// response formatting
bool verbose = false;
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_msg oaicompat_msg; // to be populated by update()
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
bool is_updated = false;
virtual int get_index() override {
return index;
}
virtual bool is_stop() override {
return true; // in stream mode, final responses are considered stop
}
virtual json to_json() override;
virtual void update(task_result_state & state) override {
is_updated = true;
oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs);
}
json to_json_non_oaicompat();
json to_json_oaicompat();
json to_json_oaicompat_chat();
json to_json_oaicompat_chat_stream();
json to_json_anthropic();
json to_json_anthropic_stream();
};
struct server_task_result_cmpl_partial : server_task_result {
int index = 0;
std::string content;
llama_tokens tokens;
int32_t n_decoded;
int32_t n_prompt_tokens;
bool post_sampling_probs;
bool is_progress = false;
completion_token_output prob_output;
result_timings timings;
result_prompt_progress progress;
// response formatting
bool verbose = false;
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs; // to be populated by update()
bool is_updated = false;
virtual int get_index() override {
return index;
}
virtual bool is_stop() override {
return false; // in stream mode, partial responses are not considered stop
}
virtual json to_json() override;
virtual void update(task_result_state & state) override {
is_updated = true;
state.update_chat_msg(content, true, oaicompat_msg_diffs);
}
json to_json_non_oaicompat();
json to_json_oaicompat();
json to_json_oaicompat_chat();
json to_json_anthropic();
};
struct server_task_result_embd : server_task_result {
int index = 0;
std::vector<std::vector<float>> embedding;
int32_t n_tokens;
// response formatting
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
virtual int get_index() override {
return index;
}
virtual json to_json() override;
json to_json_non_oaicompat();
json to_json_oaicompat();
};
struct server_task_result_rerank : server_task_result {
int index = 0;
float score = -1e6;
int32_t n_tokens;
virtual int get_index() override {
return index;
}
virtual json to_json() override;
};
struct server_task_result_error : server_task_result {
int index = 0;
error_type err_type = ERROR_TYPE_SERVER;
std::string err_msg;
// for ERROR_TYPE_EXCEED_CONTEXT_SIZE
int32_t n_prompt_tokens = 0;
int32_t n_ctx = 0;
virtual bool is_error() override {
return true;
}
virtual json to_json() override;
};
// Result of a metrics task: a snapshot of the server counters plus per-slot data.
struct server_task_result_metrics : server_task_result {
    // defaulted like the counters below so that no field is read while indeterminate
    int     n_idle_slots       = 0;
    int     n_processing_slots = 0;
    int     n_tasks_deferred   = 0;
    int64_t t_start            = 0;

    // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
    uint64_t n_prompt_tokens_processed_total = 0;
    uint64_t t_prompt_processing_total       = 0;
    uint64_t n_tokens_predicted_total        = 0;
    uint64_t t_tokens_generation_total       = 0;

    uint64_t n_tokens_max = 0;

    uint64_t n_prompt_tokens_processed = 0;
    uint64_t t_prompt_processing       = 0;

    uint64_t n_tokens_predicted  = 0;
    uint64_t t_tokens_generation = 0;

    uint64_t n_decode_total     = 0;
    uint64_t n_busy_slots_total = 0;

    // while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy
    // therefore, we use json to temporarily store the slot.to_json() result
    json slots_data = json::array();

    json to_json() override;
};
// Result of a slot save or slot restore task.
// Scalars are defaulted so they are never read while indeterminate; the producer
// is expected to overwrite all of them.
struct server_task_result_slot_save_load : server_task_result {
    std::string filename;
    bool   is_save  = false; // true = save, false = load

    size_t n_tokens = 0;
    size_t n_bytes  = 0;
    double t_ms     = 0.0;

    json to_json() override;
};
// Result of a slot erase task.
struct server_task_result_slot_erase : server_task_result {
    size_t n_erased = 0; // defaulted so it is never read while indeterminate

    json to_json() override;
};
// Result type with no payload of its own; it only supplies a to_json() override.
struct server_task_result_apply_lora : server_task_result {
    json to_json() override;
};
// A checkpoint stored with a server_prompt: a position range plus a byte buffer.
struct server_prompt_checkpoint {
    llama_pos pos_min;
    llama_pos pos_max;

    std::vector<uint8_t> data;

    // number of bytes held in `data`
    size_t size() const { return data.size(); }
};
struct server_prompt {
server_tokens tokens;
std::vector<uint8_t> data;
std::list<server_prompt_checkpoint> checkpoints;
size_t size() const {
size_t res = data.size();
for (const auto & checkpoint : checkpoints) {
res += checkpoint.size();
}
return res;
}
int n_tokens() const {
return tokens.size();
}
server_prompt clone() const {
return server_prompt {
tokens.clone(),
data,
checkpoints
};
}
};
// Holds a list of server_prompt entries together with optional size and token limits.
struct server_prompt_cache {
    // limit_size_mib is given in MiB and stored in bytes; negative values are treated as 0
    server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens)
        : limit_size(1024ull*1024ull*(limit_size_mib > 0 ? limit_size_mib : 0)),
          limit_tokens(limit_tokens) {}

    std::list<server_prompt> states;

    // in bytes, 0 = no limit
    size_t limit_size = 0;

    // in tokens, 0 = no limit
    size_t limit_tokens = 0;

    // the members below are only declared here; their definitions live in another translation unit
    size_t size() const;

    size_t n_tokens() const;

    server_prompt * alloc(const server_prompt & prompt, size_t state_size);

    bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot);

    void update();
};