Merge a67f73c880 into 3bc8d2cf23
This commit is contained in:
commit
59753f5c46
|
|
@ -84,6 +84,14 @@ std::string gen_tool_call_id() {
|
|||
return random_string();
|
||||
}
|
||||
|
||||
std::string ltrim(std::string s) {
|
||||
s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) {
|
||||
return !std::isspace(ch) && !std::ispunct(ch);
|
||||
}));
|
||||
|
||||
return s;
|
||||
}
|
||||
|
||||
//
|
||||
// lora utils
|
||||
//
|
||||
|
|
|
|||
|
|
@ -91,6 +91,7 @@ json format_error_response(const std::string & message, const enum error_type ty
|
|||
std::string random_string();
|
||||
std::string gen_chatcmplid();
|
||||
std::string gen_tool_call_id();
|
||||
std::string ltrim(std::string s);
|
||||
|
||||
//
|
||||
// lora utils
|
||||
|
|
@ -196,6 +197,10 @@ public:
|
|||
tokens.clear();
|
||||
}
|
||||
|
||||
llama_token back() const { return tokens.back(); }
|
||||
|
||||
void pop_back() { tokens.pop_back(); }
|
||||
|
||||
void keep_first(size_t n);
|
||||
|
||||
std::string detokenize(const llama_context * ctx, bool special) const;
|
||||
|
|
|
|||
|
|
@ -13,6 +13,9 @@
|
|||
#include "mtmd-helper.h"
|
||||
|
||||
#include <cstddef>
|
||||
|
||||
#include <bsm/audit.h>
|
||||
|
||||
#include <cinttypes>
|
||||
#include <memory>
|
||||
#include <filesystem>
|
||||
|
|
@ -1123,6 +1126,13 @@ private:
|
|||
}
|
||||
}
|
||||
|
||||
if (task.params.n_token_healing_enabled) {
|
||||
task.token_healing_params.healing_token = task.tokens.back();
|
||||
task.token_healing_params.healing_token_text = ltrim(common_token_to_piece(ctx, task.token_healing_params.healing_token));
|
||||
task.tokens.pop_back();
|
||||
SLT_DBG(slot, "Token healing enabled, removed last token: %d ('%s')\n",task.token_healing_params.healing_token, task.token_healing_params.healing_token_text.c_str());
|
||||
}
|
||||
|
||||
if (!task.tokens.validate(ctx)) {
|
||||
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
|
|
@ -2739,6 +2749,41 @@ private:
|
|||
|
||||
llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
|
||||
|
||||
if (slot.task->params.n_token_healing_enabled && !slot.task->token_healing_result.token_healing_complete) {
|
||||
//need to compare sampled token with the prefix of the healing token
|
||||
bool continue_sampling = true;
|
||||
int num_attempts = 0;
|
||||
|
||||
while (continue_sampling) {
|
||||
const std::string token_text = ltrim(common_token_to_piece(ctx, id));
|
||||
|
||||
if (token_text.empty()) {
|
||||
break;
|
||||
}
|
||||
|
||||
bool matched_healing_token = false;
|
||||
|
||||
size_t matched_position = token_text.find(slot.task->token_healing_params.healing_token_text);
|
||||
|
||||
if (matched_position != std::string::npos) {
|
||||
matched_healing_token = true;
|
||||
}
|
||||
|
||||
if (!matched_healing_token && num_attempts < slot.task->params.n_token_healing_max_retries) {
|
||||
id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
|
||||
num_attempts ++;
|
||||
} else {
|
||||
if (matched_healing_token) {
|
||||
//If token healing matched, we only want to return a substring AFTER the match.
|
||||
slot.task->token_healing_result.accepted_healing_token_text = token_text.substr(matched_position + slot.task->token_healing_params.healing_token_text.length());
|
||||
} else {
|
||||
slot.task->token_healing_result.accepted_healing_token_text = token_text;
|
||||
}
|
||||
continue_sampling = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
slot.i_batch = -1;
|
||||
|
||||
common_sampler_accept(slot.smpl.get(), id, true);
|
||||
|
|
@ -2758,7 +2803,16 @@ private:
|
|||
|
||||
completion_token_output result;
|
||||
result.tok = id;
|
||||
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
|
||||
|
||||
if (slot.task->params.n_token_healing_enabled && !slot.task->token_healing_result.token_healing_complete) {
|
||||
if (!slot.task->token_healing_result.accepted_healing_token_text.empty()) {
|
||||
result.text_to_send = slot.task->token_healing_result.accepted_healing_token_text;
|
||||
slot.task->token_healing_result.token_healing_complete = true;
|
||||
}
|
||||
} else {
|
||||
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
|
||||
}
|
||||
|
||||
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
|
||||
|
||||
if (slot.task->params.sampling.n_probs > 0) {
|
||||
|
|
|
|||
|
|
@ -509,6 +509,9 @@ task_params server_task::params_from_json_cmpl(
|
|||
throw std::runtime_error("n_cmpl cannot be greater than the number of slots, please increase -np");
|
||||
}
|
||||
|
||||
params.n_token_healing_enabled = json_value(data, "n_token_healing_enabled", defaults.n_token_healing_enabled);
|
||||
params.n_token_healing_max_retries = json_value(data, "n_token_healing_max_retries", defaults.n_token_healing_max_retries);
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -88,6 +88,10 @@ struct task_params {
|
|||
|
||||
json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const;
|
||||
json to_json(bool only_metrics = false) const;
|
||||
|
||||
//for completion tooling
|
||||
bool n_token_healing_enabled = false;
|
||||
int32_t n_token_healing_max_retries = 3;
|
||||
};
|
||||
|
||||
// struct for tracking the state of a task (e.g., for streaming)
|
||||
|
|
@ -123,6 +127,16 @@ struct task_result_state {
|
|||
std::vector<common_chat_msg_diff> & diffs);
|
||||
};
|
||||
|
||||
struct token_healing_params {
|
||||
llama_token healing_token;
|
||||
std::string healing_token_text;
|
||||
};
|
||||
|
||||
struct token_healing_result {
|
||||
mutable std::string accepted_healing_token_text;
|
||||
mutable bool token_healing_complete;
|
||||
};
|
||||
|
||||
struct server_task {
|
||||
int id = -1; // to be filled by server_queue
|
||||
|
||||
|
|
@ -253,6 +267,10 @@ struct server_task {
|
|||
bool is_child() const {
|
||||
return id_parent != -1;
|
||||
}
|
||||
|
||||
token_healing_params token_healing_params;
|
||||
|
||||
token_healing_result token_healing_result;
|
||||
};
|
||||
|
||||
struct result_timings {
|
||||
|
|
|
|||
Loading…
Reference in New Issue