feat: add token healing support

This commit is contained in:
Jake Chavis 2026-01-31 22:07:19 -05:00
parent 41ea26144e
commit 19b67ed609
6 changed files with 100 additions and 10 deletions

View File

@ -311,6 +311,7 @@ int main(int argc, char ** argv) {
} else {
// otherwise use the prompt as is
prompt = params.prompt;
prompt += "; ignore; the capital of France is:";
}
if (params.interactive_first || !prompt.empty() || session_tokens.empty()) {

View File

@ -84,6 +84,14 @@ std::string gen_tool_call_id() {
return random_string();
}
// Return a copy of `s` with every leading whitespace *and* punctuation
// character removed. NOTE: despite the name, this is stricter than a
// conventional ltrim — punctuation is stripped too (used to normalize
// token piece text for token-healing prefix comparison).
std::string ltrim(std::string s) {
    size_t skip = 0;
    while (skip < s.size()) {
        const unsigned char ch = static_cast<unsigned char>(s[skip]);
        if (!std::isspace(ch) && !std::ispunct(ch)) {
            break; // first character worth keeping
        }
        ++skip;
    }
    s.erase(0, skip);
    return s;
}
//
// lora utils
//

View File

@ -91,6 +91,7 @@ json format_error_response(const std::string & message, const enum error_type ty
std::string random_string();
std::string gen_chatcmplid();
std::string gen_tool_call_id();
std::string ltrim(std::string s);
//
// lora utils
@ -196,6 +197,10 @@ public:
tokens.clear();
}
llama_token back() const { return tokens.back(); }
void pop_back() { tokens.pop_back(); }
void keep_first(size_t n);
std::string detokenize(const llama_context * ctx, bool special) const;

View File

@ -1,21 +1,23 @@
#include "server-context.h"
#include "server-common.h"
#include "server-http.h"
#include "server-task.h"
#include "server-queue.h"
#include "common.h"
#include "llama.h"
#include "log.h"
#include "sampling.h"
#include "speculative.h"
#include "mtmd.h"
#include "mtmd-helper.h"
#include "mtmd.h"
#include "sampling.h"
#include "server-common.h"
#include "server-http.h"
#include "server-queue.h"
#include "server-task.h"
#include "speculative.h"
#include <bsm/audit.h>
#include <cstddef>
#include <cinttypes>
#include <memory>
#include <cstddef>
#include <filesystem>
#include <memory>
// fix problem with std::min and std::max
#if defined(_WIN32)
@ -1123,6 +1125,13 @@ private:
}
}
if (task.params.n_token_healing_enabled) {
task.token_healing_params.healing_token = task.tokens.back();
task.token_healing_params.healing_token_text = ltrim( common_token_to_piece(ctx, task.token_healing_params.healing_token));
task.tokens.pop_back();
SLT_DBG(slot, "Token healing enabled, removed last token: %d ('%s')\n",task.token_healing_params.healing_token, task.token_healing_params.healing_token_text.c_str());
}
if (!task.tokens.validate(ctx)) {
send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST);
return false;
@ -2739,6 +2748,41 @@ private:
llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
if (slot.task->params.n_token_healing_enabled && !slot.task->token_healing_result.token_healing_complete) {
//need to compare sampled token with the prefix of the healing token
bool continue_sampling = true;
int num_attempts = 0;
while (continue_sampling) {
const std::string token_text = ltrim(common_token_to_piece(ctx, id));
if (token_text.empty()) {
break;
}
bool matched_healing_token = false;
size_t matched_position = token_text.find(slot.task->token_healing_params.healing_token_text);
if (matched_position != std::string::npos) {
matched_healing_token = true;
}
if (!matched_healing_token && num_attempts < slot.task->params.n_token_healing_max_retries) {
id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
num_attempts ++;
} else {
if (matched_healing_token) {
//If token healing matched, we only want to return a substring AFTER the match.
slot.task->token_healing_result.accepted_healing_token_text = token_text.substr(matched_position + slot.task->token_healing_params.healing_token_text.length());
} else {
slot.task->token_healing_result.accepted_healing_token_text = token_text;
}
continue_sampling = false;
}
}
}
slot.i_batch = -1;
common_sampler_accept(slot.smpl.get(), id, true);
@ -2758,7 +2802,16 @@ private:
completion_token_output result;
result.tok = id;
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
if (slot.task->params.n_token_healing_enabled && !slot.task->token_healing_result.token_healing_complete) {
if (!slot.task->token_healing_result.accepted_healing_token_text.empty()) {
result.text_to_send = slot.task->token_healing_result.accepted_healing_token_text;
slot.task->token_healing_result.token_healing_complete = true;
}
} else {
result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));
}
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
if (slot.task->params.sampling.n_probs > 0) {
@ -2954,6 +3007,8 @@ std::unique_ptr<server_res_generator> server_routes::handle_completions_impl(
std::vector<server_task> tasks;
const auto & prompt = data.at("prompt");
SRV_INF("\n\nYOOO (UPDATED) this is the System Prompt: %s\n\n", prompt.get<std::string>().c_str());
// TODO: this log can become very long, put it behind a flag or think about a more compact format
//SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get<std::string>().c_str() : prompt.dump(2).c_str());

View File

@ -509,6 +509,9 @@ task_params server_task::params_from_json_cmpl(
throw std::runtime_error("n_cmpl cannot be greater than the number of slots, please increase -np");
}
params.n_token_healing_enabled = json_value(data, "n_token_healing_enabled", defaults.n_token_healing_enabled);
params.n_token_healing_max_retries = json_value(data, "n_token_healing_max_retries", defaults.n_token_healing_max_retries);
return params;
}

View File

@ -88,6 +88,10 @@ struct task_params {
json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) const;
json to_json(bool only_metrics = false) const;
//for completion tooling
bool n_token_healing_enabled = false;
int32_t n_token_healing_max_retries = 3;
};
// struct for tracking the state of a task (e.g., for streaming)
@ -123,6 +127,16 @@ struct task_result_state {
std::vector<common_chat_msg_diff> & diffs);
};
// Inputs captured at prompt-processing time when token healing is enabled:
// the trailing prompt token that was popped so it can be re-sampled.
// Fix: give the members defaults so a task that never runs the healing
// setup path does not carry an indeterminate token id.
struct token_healing_params {
    llama_token healing_token = LLAMA_TOKEN_NULL; // popped trailing token; LLAMA_TOKEN_NULL = unset
    std::string healing_token_text;               // ltrim'd piece text of that token
};
// Mutable per-task state of the healing pass, updated from the sampling loop.
// Fix: `token_healing_complete` was an uninitialized bool, but the sampling
// path reads it (`!...token_healing_complete`) before anything assigns it —
// an indeterminate read. Default it to false.
struct token_healing_result {
    mutable std::string accepted_healing_token_text; // suffix to emit after the matched healing prefix
    mutable bool token_healing_complete = false;     // set once the healed text has been sent
};
struct server_task {
int id = -1; // to be filled by server_queue
@ -253,6 +267,10 @@ struct server_task {
bool is_child() const {
return id_parent != -1;
}
token_healing_params token_healing_params;
token_healing_result token_healing_result;
};
struct result_timings {