feat: add token healing support
This commit is contained in:
parent
41ea26144e
commit
19b67ed609
|
|
@ -311,6 +311,7 @@ int main(int argc, char ** argv) {
|
||||||
} else {
|
} else {
|
||||||
// otherwise use the prompt as is
|
// otherwise use the prompt as is
|
||||||
prompt = params.prompt;
|
prompt = params.prompt;
|
||||||
|
prompt += "; ignore; the capital of France is:";
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.interactive_first || !prompt.empty() || session_tokens.empty()) {
|
if (params.interactive_first || !prompt.empty() || session_tokens.empty()) {
|
||||||
|
|
|
||||||
|
|
@ -84,6 +84,14 @@ std::string gen_tool_call_id() {
|
||||||
return random_string();
|
return random_string();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string ltrim(std::string s) {
|
||||||
|
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
|
||||||
|
return !std::isspace(ch) && !std::ispunct(ch);
|
||||||
|
}));
|
||||||
|
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// lora utils
|
// lora utils
|
||||||
//
|
//
|
||||||
|
|
|
||||||
|
|
@ -91,6 +91,7 @@ json format_error_response(const std::string & message, const enum error_type ty
|
||||||
std::string random_string();
|
std::string random_string();
|
||||||
std::string gen_chatcmplid();
|
std::string gen_chatcmplid();
|
||||||
std::string gen_tool_call_id();
|
std::string gen_tool_call_id();
|
||||||
|
std::string ltrim(std::string s);
|
||||||
|
|
||||||
//
|
//
|
||||||
// lora utils
|
// lora utils
|
||||||
|
|
@ -196,6 +197,10 @@ public:
|
||||||
tokens.clear();
|
tokens.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_token back() const { return tokens.back(); }
|
||||||
|
|
||||||
|
void pop_back() { tokens.pop_back(); }
|
||||||
|
|
||||||
void keep_first(size_t n);
|
void keep_first(size_t n);
|
||||||
|
|
||||||
std::string detokenize(const llama_context * ctx, bool special) const;
|
std::string detokenize(const llama_context * ctx, bool special) const;
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,23 @@
|
||||||
#include "server-context.h"
|
#include "server-context.h"
|
||||||
#include "server-common.h"
|
|
||||||
#include "server-http.h"
|
|
||||||
#include "server-task.h"
|
|
||||||
#include "server-queue.h"
|
|
||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "log.h"
|
#include "log.h"
|
||||||
#include "sampling.h"
|
|
||||||
#include "speculative.h"
|
|
||||||
#include "mtmd.h"
|
|
||||||
#include "mtmd-helper.h"
|
#include "mtmd-helper.h"
|
||||||
|
#include "mtmd.h"
|
||||||
|
#include "sampling.h"
|
||||||
|
#include "server-common.h"
|
||||||
|
#include "server-http.h"
|
||||||
|
#include "server-queue.h"
|
||||||
|
#include "server-task.h"
|
||||||
|
#include "speculative.h"
|
||||||
|
|
||||||
|
#include <bsm/audit.h>
|
||||||
|
|
||||||
#include <cstddef>
|
|
||||||
#include <cinttypes>
|
#include <cinttypes>
|
||||||
#include <memory>
|
#include <cstddef>
|
||||||
#include <filesystem>
|
#include <filesystem>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
// fix problem with std::min and std::max
|
// fix problem with std::min and std::max
|
||||||
#if defined(_WIN32)
|
#if defined(_WIN32)
|
||||||
|
|
@ -1123,6 +1125,13 @@ private:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (task.params.n_token_healing_enabled) {
|
||||||
|
task.token_healing_params.healing_token = task.tokens.back();
|
||||||
|
task.token_healing_params.healing_token_text = ltrim( common_token_to_piece(ctx, task.token_healing_params.healing_token));
|
||||||
|
task.tokens.pop_back();
|
||||||
|
SLT_DBG(slot, "Token healing enabled, removed last token: %d ('%s')\n",task.token_healing_params.healing_token, task.token_healing_params.healing_token_text.c_str());
|
||||||
|
}
|
||||||
|
|
||||||
if (!task.tokens.validate(ctx)) {
|
if (!task.tokens.validate(ctx)) {
|
||||||
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
|
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
|
||||||
return false;
|
return false;
|
||||||
|
|
@ -2739,6 +2748,41 @@ private:
|
||||||
|
|
||||||
llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
|
llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
|
||||||
|
|
||||||
|
if (slot.task->params.n_token_healing_enabled && !slot.task->token_healing_result.token_healing_complete) {
|
||||||
|
//need to compare sampled token with the prefix of the healing token
|
||||||
|
bool continue_sampling = true;
|
||||||
|
int num_attempts = 0;
|
||||||
|
|
||||||
|
while (continue_sampling) {
|
||||||
|
const std::string token_text = ltrim(common_token_to_piece(ctx, id));
|
||||||
|
|
||||||
|
if (token_text.empty()) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool matched_healing_token = false;
|
||||||
|
|
||||||
|
size_t matched_position = token_text.find(slot.task->token_healing_params.healing_token_text);
|
||||||
|
|
||||||
|
if (matched_position != std::string::npos) {
|
||||||
|
matched_healing_token = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!matched_healing_token && num_attempts < slot.task->params.n_token_healing_max_retries) {
|
||||||
|
id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
|
||||||
|
num_attempts ++;
|
||||||
|
} else {
|
||||||
|
if (matched_healing_token) {
|
||||||
|
//If token healing matched, we only want to return a substring AFTER the match.
|
||||||
|
slot.task->token_healing_result.accepted_healing_token_text = token_text.substr(matched_position + slot.task->token_healing_params.healing_token_text.length());
|
||||||
|
} else {
|
||||||
|
slot.task->token_healing_result.accepted_healing_token_text = token_text;
|
||||||
|
}
|
||||||
|
continue_sampling = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
slot.i_batch = -1;
|
slot.i_batch = -1;
|
||||||
|
|
||||||
common_sampler_accept(slot.smpl.get(), id, true);
|
common_sampler_accept(slot.smpl.get(), id, true);
|
||||||
|
|
@ -2758,7 +2802,16 @@ private:
|
||||||
|
|
||||||
completion_token_output result;
|
completion_token_output result;
|
||||||
result.tok = id;
|
result.tok = id;
|
||||||
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
|
|
||||||
|
if (slot.task->params.n_token_healing_enabled && !slot.task->token_healing_result.token_healing_complete) {
|
||||||
|
if (!slot.task->token_healing_result.accepted_healing_token_text.empty()) {
|
||||||
|
result.text_to_send = slot.task->token_healing_result.accepted_healing_token_text;
|
||||||
|
slot.task->token_healing_result.token_healing_complete = true;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
|
||||||
|
}
|
||||||
|
|
||||||
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
|
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
|
||||||
|
|
||||||
if (slot.task->params.sampling.n_probs > 0) {
|
if (slot.task->params.sampling.n_probs > 0) {
|
||||||
|
|
@ -2954,6 +3007,8 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
|
||||||
std::vector<server_task> tasks;
|
std::vector<server_task> tasks;
|
||||||
|
|
||||||
const auto & prompt = data.at("prompt");
|
const auto & prompt = data.at("prompt");
|
||||||
|
|
||||||
|
SRV_INF("\n\nYOOO (UPDATED) this is the System Prompt: %s\n\n", prompt.get<std::string>().c_str());
|
||||||
// TODO: this log can become very long, put it behind a flag or think about a more compact format
|
// TODO: this log can become very long, put it behind a flag or think about a more compact format
|
||||||
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
|
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -509,6 +509,9 @@ task_params server_task::params_from_json_cmpl(
|
||||||
throw std::runtime_error("n_cmpl cannot be greater than the number of slots, please increase -np");
|
throw std::runtime_error("n_cmpl cannot be greater than the number of slots, please increase -np");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
params.n_token_healing_enabled = json_value(data, "n_token_healing_enabled", defaults.n_token_healing_enabled);
|
||||||
|
params.n_token_healing_max_retries = json_value(data, "n_token_healing_max_retries", defaults.n_token_healing_max_retries);
|
||||||
|
|
||||||
return params;
|
return params;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -88,6 +88,10 @@ struct task_params {
|
||||||
|
|
||||||
json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const;
|
json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const;
|
||||||
json to_json(bool only_metrics = false) const;
|
json to_json(bool only_metrics = false) const;
|
||||||
|
|
||||||
|
//for completion tooling
|
||||||
|
bool n_token_healing_enabled = false;
|
||||||
|
int32_t n_token_healing_max_retries = 3;
|
||||||
};
|
};
|
||||||
|
|
||||||
// struct for tracking the state of a task (e.g., for streaming)
|
// struct for tracking the state of a task (e.g., for streaming)
|
||||||
|
|
@ -123,6 +127,16 @@ struct task_result_state {
|
||||||
std::vector<common_chat_msg_diff> & diffs);
|
std::vector<common_chat_msg_diff> & diffs);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct token_healing_params {
|
||||||
|
llama_token healing_token;
|
||||||
|
std::string healing_token_text;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct token_healing_result {
|
||||||
|
mutable std::string accepted_healing_token_text;
|
||||||
|
mutable bool token_healing_complete;
|
||||||
|
};
|
||||||
|
|
||||||
struct server_task {
|
struct server_task {
|
||||||
int id = -1; // to be filled by server_queue
|
int id = -1; // to be filled by server_queue
|
||||||
|
|
||||||
|
|
@ -253,6 +267,10 @@ struct server_task {
|
||||||
bool is_child() const {
|
bool is_child() const {
|
||||||
return id_parent != -1;
|
return id_parent != -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
token_healing_params token_healing_params;
|
||||||
|
|
||||||
|
token_healing_result token_healing_result;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct result_timings {
|
struct result_timings {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue