From 19b67ed6093fa52b9f0bcbdf384f7137caceba90 Mon Sep 17 00:00:00 2001 From: Jake Chavis Date: Sat, 31 Jan 2026 22:07:19 -0500 Subject: [PATCH] feat: add token healing support --- tools/server/server-common.cpp | 8 ++++ tools/server/server-common.h | 5 +++ tools/server/server-context.cpp | 75 ++++++++++++++++++++++++++++----- tools/server/server-task.cpp | 3 ++ tools/server/server-task.h | 18 ++++++++ 5 files changed, 99 insertions(+), 10 deletions(-) diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index a853f65c8d..101f1921f3 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -84,6 +84,14 @@ std::string gen_tool_call_id() { return random_string(); } +std::string ltrim(std::string s) { + s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](unsigned char ch) { + return !std::isspace(ch) && !std::ispunct(ch); + })); + + return s; +} + // // lora utils // diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 2629a6bee9..8b762d6a28 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -91,6 +91,7 @@ json format_error_response(const std::string & message, const enum error_type ty std::string random_string(); std::string gen_chatcmplid(); std::string gen_tool_call_id(); +std::string ltrim(std::string s); // // lora utils @@ -196,6 +197,10 @@ public: tokens.clear(); } + llama_token back() const { return tokens.back(); } + + void pop_back() { tokens.pop_back(); } + void keep_first(size_t
n); std::string detokenize(const llama_context * ctx, bool special) const; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 7f9c3c566b..06eb386369 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1,21 +1,23 @@ #include "server-context.h" -#include "server-common.h" -#include "server-http.h" -#include "server-task.h" -#include "server-queue.h" #include "common.h" #include "llama.h" #include "log.h" -#include "sampling.h" -#include "speculative.h" -#include "mtmd.h" #include "mtmd-helper.h" +#include "mtmd.h" +#include "sampling.h" +#include "server-common.h" +#include "server-http.h" +#include "server-queue.h" +#include "server-task.h" +#include "speculative.h" + +#include -#include #include -#include #include +#include // fix problem with std::min and std::max #if defined(_WIN32) @@ -1123,6 +1125,13 @@ private: } } + if (task.params.n_token_healing_enabled && !task.tokens.empty()) { + task.token_healing_params.healing_token = task.tokens.back(); + task.token_healing_params.healing_token_text = ltrim( common_token_to_piece(ctx, task.token_healing_params.healing_token)); + task.tokens.pop_back(); + SLT_DBG(slot, "Token healing enabled, removed last token: %d ('%s')\n",task.token_healing_params.healing_token, task.token_healing_params.healing_token_text.c_str()); + } + if (!task.tokens.validate(ctx)) { send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); return false; } @@ -2739,6 +2748,41 @@ private: llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx); + if (slot.task->params.n_token_healing_enabled && !slot.task->token_healing_result.token_healing_complete) { + //need to compare sampled token with the prefix of the healing token + bool continue_sampling = true; + int num_attempts = 0; + + while (continue_sampling) { + const std::string token_text = ltrim(common_token_to_piece(ctx, id)); + + if (token_text.empty()) { + slot.task->token_healing_result.token_healing_complete = true; break; //cannot match an empty piece; stop healing so this token still streams + } + + bool matched_healing_token =
false; + + size_t matched_position = token_text.rfind(slot.task->token_healing_params.healing_token_text, 0); + + if (matched_position != std::string::npos) { + matched_healing_token = true; + } + + if (!matched_healing_token && num_attempts < slot.task->params.n_token_healing_max_retries) { + id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx); + num_attempts ++; + } else { + if (matched_healing_token) { + //If token healing matched, we only want to return a substring AFTER the match. + slot.task->token_healing_result.accepted_healing_token_text = token_text.substr(matched_position + slot.task->token_healing_params.healing_token_text.length()); + } else { + slot.task->token_healing_result.accepted_healing_token_text = token_text; + } + continue_sampling = false; + } + } + } + slot.i_batch = -1; common_sampler_accept(slot.smpl.get(), id, true); @@ -2758,7 +2802,16 @@ completion_token_output result; result.tok = id; - result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + + if (slot.task->params.n_token_healing_enabled && !slot.task->token_healing_result.token_healing_complete) { + //accepted text may be empty when the sampled token matched the healed prefix exactly; + //always mark healing complete so later tokens stream normally + result.text_to_send = slot.task->token_healing_result.accepted_healing_token_text; + slot.task->token_healing_result.token_healing_complete = true; + } else { + result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); + } + result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs if (slot.task->params.sampling.n_probs > 0) { @@ -2954,6 +3007,8 @@ std::unique_ptr server_routes::handle_completions_impl( std::vector tasks; const auto & prompt = data.at("prompt"); + + //NOTE(review): removed temporary debug print of the raw prompt; prompt may not be a string here // TODO: this log can become very long, put it behind a flag or think about a more compact format //SRV_DBG("Prompt: %s\n",
prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 2d25db63b7..540fa65226 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -509,6 +509,9 @@ task_params server_task::params_from_json_cmpl( throw std::runtime_error("n_cmpl cannot be greater than the number of slots, please increase -np"); } + params.n_token_healing_enabled = json_value(data, "n_token_healing_enabled", defaults.n_token_healing_enabled); + params.n_token_healing_max_retries = json_value(data, "n_token_healing_max_retries", defaults.n_token_healing_max_retries); + return params; } diff --git a/tools/server/server-task.h b/tools/server/server-task.h index a69e8f1a3d..c54c158673 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -88,6 +88,10 @@ struct task_params { json format_logit_bias(const std::vector & logit_bias) const; json to_json(bool only_metrics = false) const; + + //for completion tooling + bool n_token_healing_enabled = false; + int32_t n_token_healing_max_retries = 3; }; // struct for tracking the state of a task (e.g., for streaming) @@ -123,6 +127,16 @@ struct task_result_state { std::vector & diffs); }; +struct token_healing_params { + llama_token healing_token = LLAMA_TOKEN_NULL; + std::string healing_token_text; +}; + +struct token_healing_result { + mutable std::string accepted_healing_token_text; + mutable bool token_healing_complete = false; +}; + struct server_task { int id = -1; // to be filled by server_queue @@ -253,6 +267,10 @@ struct server_task { bool is_child() const { return id_parent != -1; } + + token_healing_params token_healing_params; + + token_healing_result token_healing_result; }; struct result_timings {