Merge 02d4c32517 into 825eb91a66
This commit is contained in:
commit
26cc6f60e6
|
|
@ -3100,8 +3100,19 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
"message injected before the end-of-thinking tag when reasoning budget is exhausted (default: none)",
|
||||
[](common_params & params, const std::string & value) {
|
||||
params.reasoning_budget_message = value;
|
||||
params.sampling.reasoning_budget_message = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_MESSAGE"));
|
||||
add_opt(common_arg(
|
||||
{"--reasoning-budget-conclusion"}, "N",
|
||||
"token budget for conclusion phase after message injection (0 = disabled, default: 0).\n"
|
||||
"When set, the model is given N tokens to conclude naturally after the budget message\n"
|
||||
"is injected, instead of immediately forcing the end-of-thinking tag.",
|
||||
[](common_params & params, int value) {
|
||||
if (value < 0) { throw std::invalid_argument("invalid value for --reasoning-budget-conclusion: must be >= 0"); }
|
||||
params.sampling.reasoning_budget_conclusion = value;
|
||||
}
|
||||
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_THINK_BUDGET_CONCLUSION"));
|
||||
add_opt(common_arg(
|
||||
{"--chat-template"}, "JINJA_TEMPLATE",
|
||||
string_format(
|
||||
|
|
@ -3894,4 +3905,4 @@ void common_params_add_preset_options(std::vector<common_arg> & args) {
|
|||
// "in server router mode, do not unload this model if models_max is exceeded",
|
||||
// [](common_params &) { /* unused */ }
|
||||
// ).set_preset_only());
|
||||
}
|
||||
}
|
||||
|
|
@ -283,10 +283,12 @@ struct common_params_sampling {
|
|||
|
||||
// reasoning budget sampler parameters
|
||||
// these are populated by the server/CLI based on chat template params
|
||||
int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
|
||||
std::vector<llama_token> reasoning_budget_start; // start tag token sequence
|
||||
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
|
||||
std::vector<llama_token> reasoning_budget_forced; // forced sequence (message + end tag)
|
||||
int32_t reasoning_budget_tokens = -1; // -1 = disabled, >= 0 = token budget
|
||||
int32_t reasoning_budget_conclusion = 0; // tokens reserved for conclusion phase (0 = disabled)
|
||||
std::string reasoning_budget_message; // message injected at start of conclusion phase
|
||||
std::vector<llama_token> reasoning_budget_start; // start tag token sequence
|
||||
std::vector<llama_token> reasoning_budget_end; // end tag token sequence
|
||||
std::vector<llama_token> reasoning_budget_forced; // forced sequence (end tag, hard-cutoff safety net)
|
||||
|
||||
bool backend_sampling = false;
|
||||
|
||||
|
|
@ -995,4 +997,4 @@ inline llama_model_tensor_buft_override llm_ffn_exps_cpu_override() {
|
|||
ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
|
||||
|
||||
// "adamw" or "sgd" (case insensitive)
|
||||
enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
|
||||
enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
|
||||
|
|
@ -41,15 +41,18 @@ struct common_reasoning_budget_ctx {
|
|||
|
||||
token_matcher start_matcher;
|
||||
token_matcher end_matcher;
|
||||
std::vector<llama_token> forced_tokens;
|
||||
std::vector<llama_token> forced_tokens; // end-of-thinking sequence (hard-cutoff safety net)
|
||||
std::vector<llama_token> message_tokens; // message injected at conclusion phase start
|
||||
|
||||
int32_t budget; // maximum tokens in reasoning block
|
||||
int32_t remaining; // tokens remaining in budget
|
||||
int32_t budget; // maximum tokens in thinking phase
|
||||
int32_t conclusion_budget; // tokens reserved for conclusion phase (0 = disabled)
|
||||
int32_t remaining; // tokens remaining in current phase
|
||||
|
||||
common_reasoning_budget_state state;
|
||||
|
||||
// for forcing
|
||||
size_t force_pos; // next position in forced_tokens to force
|
||||
size_t force_pos; // next position in forced_tokens to force
|
||||
size_t message_pos; // next position in message_tokens to force (during message injection)
|
||||
};
|
||||
|
||||
static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) {
|
||||
|
|
@ -76,6 +79,7 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
|||
break;
|
||||
}
|
||||
case REASONING_BUDGET_COUNTING:
|
||||
case REASONING_BUDGET_CONCLUDING:
|
||||
case REASONING_BUDGET_WAITING_UTF8:
|
||||
{
|
||||
if (ctx->end_matcher.advance(token)) {
|
||||
|
|
@ -97,10 +101,36 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
|||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
|
||||
}
|
||||
} else if (ctx->state == REASONING_BUDGET_COUNTING) {
|
||||
} else if (ctx->state == REASONING_BUDGET_CONCLUDING) {
|
||||
ctx->remaining--;
|
||||
if (ctx->remaining <= 0) {
|
||||
if (utf8_complete) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: conclusion budget exhausted, forcing end sequence\n");
|
||||
} else {
|
||||
ctx->state = REASONING_BUDGET_WAITING_UTF8;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: conclusion budget exhausted, waiting for UTF-8 completion\n");
|
||||
}
|
||||
}
|
||||
} else if (ctx->state == REASONING_BUDGET_COUNTING) {
|
||||
ctx->remaining--;
|
||||
if (ctx->remaining <= 0) {
|
||||
if (ctx->conclusion_budget > 0 && !ctx->message_tokens.empty()) {
|
||||
// Two-phase: force message tokens first, then let model conclude
|
||||
ctx->state = REASONING_BUDGET_INJECTING;
|
||||
ctx->message_pos = 0;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: thinking budget exhausted, injecting conclusion message\n");
|
||||
} else if (ctx->conclusion_budget > 0) {
|
||||
// No message, but conclusion budget set — go straight to free conclusion
|
||||
ctx->state = REASONING_BUDGET_CONCLUDING;
|
||||
ctx->remaining = ctx->conclusion_budget;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: thinking budget exhausted, entering conclusion phase (%d tokens)\n", ctx->conclusion_budget);
|
||||
} else if (utf8_complete) {
|
||||
ctx->state = REASONING_BUDGET_FORCING;
|
||||
ctx->force_pos = 0;
|
||||
ctx->end_matcher.reset();
|
||||
|
|
@ -114,6 +144,16 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
|||
}
|
||||
break;
|
||||
}
|
||||
case REASONING_BUDGET_INJECTING:
|
||||
ctx->message_pos++;
|
||||
if (ctx->message_pos >= ctx->message_tokens.size()) {
|
||||
// Message fully injected — enter free conclusion phase
|
||||
ctx->state = REASONING_BUDGET_CONCLUDING;
|
||||
ctx->remaining = ctx->conclusion_budget;
|
||||
ctx->end_matcher.reset();
|
||||
LOG_INF("reasoning-budget: message injected, entering conclusion phase (%d tokens)\n", ctx->conclusion_budget);
|
||||
}
|
||||
break;
|
||||
case REASONING_BUDGET_FORCING:
|
||||
ctx->force_pos++;
|
||||
if (ctx->force_pos >= ctx->forced_tokens.size()) {
|
||||
|
|
@ -129,6 +169,19 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
|
|||
static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
|
||||
|
||||
if (ctx->state == REASONING_BUDGET_INJECTING) {
|
||||
if (ctx->message_pos >= ctx->message_tokens.size()) {
|
||||
return;
|
||||
}
|
||||
const llama_token forced = ctx->message_tokens[ctx->message_pos];
|
||||
for (size_t i = 0; i < cur_p->size; i++) {
|
||||
if (cur_p->data[i].id != forced) {
|
||||
cur_p->data[i].logit = -INFINITY;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (ctx->state != REASONING_BUDGET_FORCING) {
|
||||
// passthrough — don't modify logits
|
||||
return;
|
||||
|
|
@ -155,13 +208,15 @@ static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
|
|||
ctx->start_matcher.reset();
|
||||
ctx->end_matcher.reset();
|
||||
ctx->force_pos = 0;
|
||||
ctx->message_pos = 0;
|
||||
}
|
||||
|
||||
// forward declaration for use in clone
|
||||
static struct llama_sampler * common_reasoning_budget_init_state(
|
||||
const struct llama_vocab * vocab, const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens, const std::vector<llama_token> & forced_tokens,
|
||||
int32_t budget, common_reasoning_budget_state initial_state);
|
||||
const std::vector<llama_token> & message_tokens,
|
||||
int32_t budget, int32_t conclusion_budget, common_reasoning_budget_state initial_state);
|
||||
|
||||
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
|
||||
|
|
@ -170,7 +225,9 @@ static struct llama_sampler * common_reasoning_budget_clone(const struct llama_s
|
|||
ctx->start_matcher.tokens,
|
||||
ctx->end_matcher.tokens,
|
||||
ctx->forced_tokens,
|
||||
ctx->message_tokens,
|
||||
ctx->budget,
|
||||
ctx->conclusion_budget,
|
||||
ctx->state);
|
||||
}
|
||||
|
||||
|
|
@ -196,7 +253,9 @@ static struct llama_sampler * common_reasoning_budget_init_state(
|
|||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
const std::vector<llama_token> & message_tokens,
|
||||
int32_t budget,
|
||||
int32_t conclusion_budget,
|
||||
common_reasoning_budget_state initial_state) {
|
||||
// promote COUNTING with budget <= 0 to FORCING
|
||||
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
|
||||
|
|
@ -206,14 +265,17 @@ static struct llama_sampler * common_reasoning_budget_init_state(
|
|||
return llama_sampler_init(
|
||||
/* .iface = */ &common_reasoning_budget_i,
|
||||
/* .ctx = */ new common_reasoning_budget_ctx {
|
||||
/* .vocab = */ vocab,
|
||||
/* .start_matcher = */ { start_tokens, 0 },
|
||||
/* .end_matcher = */ { end_tokens, 0 },
|
||||
/* .forced_tokens = */ forced_tokens,
|
||||
/* .budget = */ budget,
|
||||
/* .remaining = */ budget,
|
||||
/* .state = */ initial_state,
|
||||
/* .force_pos = */ 0,
|
||||
/* .vocab = */ vocab,
|
||||
/* .start_matcher = */ { start_tokens, 0 },
|
||||
/* .end_matcher = */ { end_tokens, 0 },
|
||||
/* .forced_tokens = */ forced_tokens,
|
||||
/* .message_tokens = */ message_tokens,
|
||||
/* .budget = */ budget,
|
||||
/* .conclusion_budget = */ conclusion_budget,
|
||||
/* .remaining = */ budget,
|
||||
/* .state = */ initial_state,
|
||||
/* .force_pos = */ 0,
|
||||
/* .message_pos = */ 0,
|
||||
}
|
||||
);
|
||||
}
|
||||
|
|
@ -223,7 +285,9 @@ struct llama_sampler * common_reasoning_budget_init(
|
|||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
const std::vector<llama_token> & message_tokens,
|
||||
int32_t budget,
|
||||
int32_t conclusion_budget,
|
||||
const std::vector<llama_token> & prefill_tokens) {
|
||||
// Determine initial state from prefill: COUNTING if the prefill begins with
|
||||
// the start sequence but does not also contain the end sequence after it.
|
||||
|
|
@ -243,7 +307,7 @@ struct llama_sampler * common_reasoning_budget_init(
|
|||
}
|
||||
}
|
||||
}
|
||||
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
|
||||
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, message_tokens, budget, conclusion_budget, initial_state);
|
||||
}
|
||||
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
|
|
@ -251,9 +315,11 @@ struct llama_sampler * common_reasoning_budget_init(
|
|||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
const std::vector<llama_token> & message_tokens,
|
||||
int32_t budget,
|
||||
int32_t conclusion_budget,
|
||||
common_reasoning_budget_state initial_state) {
|
||||
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
|
||||
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, message_tokens, budget, conclusion_budget, initial_state);
|
||||
}
|
||||
|
||||
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl) {
|
||||
|
|
@ -261,4 +327,4 @@ common_reasoning_budget_state common_reasoning_budget_get_state(const struct lla
|
|||
return REASONING_BUDGET_IDLE;
|
||||
}
|
||||
return ((const common_reasoning_budget_ctx *)smpl->ctx)->state;
|
||||
}
|
||||
}
|
||||
|
|
@ -8,7 +8,9 @@
|
|||
enum common_reasoning_budget_state {
|
||||
REASONING_BUDGET_IDLE, // waiting for start sequence
|
||||
REASONING_BUDGET_COUNTING, // counting down tokens
|
||||
REASONING_BUDGET_FORCING, // forcing budget message + end sequence
|
||||
REASONING_BUDGET_INJECTING, // forcing message tokens before conclusion phase
|
||||
REASONING_BUDGET_CONCLUDING, // conclusion phase: model free to conclude naturally
|
||||
REASONING_BUDGET_FORCING, // forcing end sequence (hard cutoff safety net)
|
||||
REASONING_BUDGET_WAITING_UTF8, // budget exhausted, waiting for UTF-8 completion
|
||||
REASONING_BUDGET_DONE, // passthrough forever
|
||||
};
|
||||
|
|
@ -16,31 +18,36 @@ enum common_reasoning_budget_state {
|
|||
// Creates a reasoning budget sampler that limits token generation inside a
|
||||
// reasoning block (e.g. between <think> and </think>).
|
||||
//
|
||||
// State machine: IDLE -> COUNTING -> WAITING_UTF8 -> FORCING -> DONE
|
||||
// State machine: IDLE -> COUNTING -> CONCLUDING -> WAITING_UTF8 -> FORCING -> DONE
|
||||
// IDLE: passthrough, watching for start_tokens sequence
|
||||
// COUNTING: counting down remaining tokens, watching for natural end_tokens
|
||||
// CONCLUDING: conclusion phase after message injection; model generates freely
|
||||
// until it produces end_tokens naturally or conclusion_budget runs out
|
||||
// WAITING_UTF8: budget exhausted, allowing tokens to complete a UTF-8 sequence
|
||||
// FORCING: forces forced_tokens token-by-token (all other logits -> -inf)
|
||||
// DONE: passthrough forever
|
||||
//
|
||||
// Parameters:
|
||||
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
|
||||
// start_tokens - token sequence that activates counting
|
||||
// end_tokens - token sequence for natural deactivation
|
||||
// forced_tokens - token sequence forced when budget expires
|
||||
// budget - max tokens allowed in the reasoning block
|
||||
// prefill_tokens - tokens already present in the prompt (generation prompt);
|
||||
// used to determine the initial state: COUNTING if they begin
|
||||
// with start_tokens (but don't also end with end_tokens),
|
||||
// IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING.
|
||||
// vocab - vocabulary (used for UTF-8 boundary detection; can be nullptr)
|
||||
// start_tokens - token sequence that activates counting
|
||||
// end_tokens - token sequence for natural deactivation
|
||||
// forced_tokens - token sequence forced when budget expires (hard-cutoff safety net)
|
||||
// budget - max tokens allowed in the thinking phase
|
||||
// conclusion_budget - tokens reserved for conclusion phase (0 = disabled, original behavior)
|
||||
// prefill_tokens - tokens already present in the prompt (generation prompt);
|
||||
// used to determine the initial state: COUNTING if they begin
|
||||
// with start_tokens (but don't also end with end_tokens),
|
||||
// IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING.
|
||||
//
|
||||
struct llama_sampler * common_reasoning_budget_init(
|
||||
const struct llama_vocab * vocab,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
const std::vector<llama_token> & message_tokens,
|
||||
int32_t budget,
|
||||
const std::vector<llama_token> & prefill_tokens = {});
|
||||
int32_t conclusion_budget = 0,
|
||||
const std::vector<llama_token> & prefill_tokens = {});
|
||||
|
||||
// Variant that takes an explicit initial state (used by tests and clone).
|
||||
// COUNTING with budget <= 0 is promoted to FORCING.
|
||||
|
|
@ -49,7 +56,9 @@ struct llama_sampler * common_reasoning_budget_init(
|
|||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
const std::vector<llama_token> & message_tokens,
|
||||
int32_t budget,
|
||||
int32_t conclusion_budget,
|
||||
common_reasoning_budget_state initial_state);
|
||||
|
||||
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);
|
||||
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);
|
||||
|
|
@ -289,12 +289,21 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
|||
|
||||
// reasoning budget sampler
|
||||
if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty()) {
|
||||
// Tokenize the budget message separately so it can be injected before the
|
||||
// conclusion phase (two-phase graceful termination). The forced_tokens
|
||||
// sequence (message + end tag) is retained as the hard-cutoff safety net.
|
||||
std::vector<llama_token> message_tokens;
|
||||
if (!params.reasoning_budget_message.empty() && params.reasoning_budget_conclusion > 0) {
|
||||
message_tokens = common_tokenize(vocab, params.reasoning_budget_message, false, true);
|
||||
}
|
||||
rbudget = common_reasoning_budget_init(
|
||||
vocab,
|
||||
params.reasoning_budget_start,
|
||||
params.reasoning_budget_end,
|
||||
params.reasoning_budget_forced,
|
||||
message_tokens,
|
||||
params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens,
|
||||
params.reasoning_budget_conclusion,
|
||||
prefill_tokens);
|
||||
}
|
||||
|
||||
|
|
@ -829,4 +838,4 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
|
|||
}
|
||||
|
||||
return samplers;
|
||||
}
|
||||
}
|
||||
|
|
@ -14,89 +14,69 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// Reasoning budget sampler test helper
|
||||
// These tests use nullptr vocab which safely falls back to treating all tokens as complete
|
||||
// (The UTF-8 boundary detection logic is tested separately in test_utf8_boundary_detection)
|
||||
static void test_reasoning_budget(
|
||||
const char * test_name,
|
||||
const std::vector<llama_token> & sequence,
|
||||
const std::vector<llama_token> & start_tokens,
|
||||
const std::vector<llama_token> & end_tokens,
|
||||
const std::vector<llama_token> & forced_tokens,
|
||||
const std::vector<llama_token> & message_tokens,
|
||||
int32_t budget,
|
||||
int32_t conclusion_budget,
|
||||
common_reasoning_budget_state initial_state,
|
||||
size_t expected_force_start, // token index where forcing should start (SIZE_MAX = never)
|
||||
size_t expected_force_end // token index where forcing should end (after this, no more forcing)
|
||||
size_t expected_force_start,
|
||||
size_t expected_force_end
|
||||
) {
|
||||
// Find the maximum token ID to ensure our vocab covers all tokens
|
||||
llama_token max_token = 0;
|
||||
for (auto t : sequence) max_token = std::max(max_token, t);
|
||||
for (auto t : start_tokens) max_token = std::max(max_token, t);
|
||||
for (auto t : end_tokens) max_token = std::max(max_token, t);
|
||||
for (auto t : forced_tokens) max_token = std::max(max_token, t);
|
||||
for (size_t k = 0; k < sequence.size(); k++) { if (sequence[k] > max_token) max_token = sequence[k]; }
|
||||
for (size_t k = 0; k < start_tokens.size(); k++) { if (start_tokens[k] > max_token) max_token = start_tokens[k]; }
|
||||
for (size_t k = 0; k < end_tokens.size(); k++) { if (end_tokens[k] > max_token) max_token = end_tokens[k]; }
|
||||
for (size_t k = 0; k < forced_tokens.size(); k++) { if (forced_tokens[k] > max_token) max_token = forced_tokens[k]; }
|
||||
for (size_t k = 0; k < message_tokens.size();k++) { if (message_tokens[k]> max_token) max_token = message_tokens[k];}
|
||||
|
||||
// Create a minimal sampler with mock vocabulary
|
||||
// For this test, we use nullptr as vocab since we're testing state transitions
|
||||
// The UTF-8 boundary check will treat all tokens as complete (safe fallback)
|
||||
auto * sampler = common_reasoning_budget_init(
|
||||
nullptr, // vocab - not used for basic state machine tests
|
||||
start_tokens,
|
||||
end_tokens,
|
||||
forced_tokens,
|
||||
budget,
|
||||
initial_state
|
||||
nullptr,
|
||||
start_tokens, end_tokens, forced_tokens, message_tokens,
|
||||
budget, conclusion_budget, initial_state
|
||||
);
|
||||
|
||||
// Create a test token data array for checking forcing behavior
|
||||
// Vocab size must be large enough to include all tokens (start, end, forced, sequence)
|
||||
std::vector<llama_token_data> cur;
|
||||
const size_t n_vocab = (size_t)max_token + 1;
|
||||
const size_t n_vocab = (size_t)(max_token + 1);
|
||||
for (size_t i = 0; i < n_vocab; i++) {
|
||||
cur.emplace_back(llama_token_data{(llama_token)i, logf((float)(i+1)), 0.0f});
|
||||
llama_token_data d;
|
||||
d.id = (llama_token)i; d.logit = logf((float)(i+1)); d.p = 0.0f;
|
||||
cur.push_back(d);
|
||||
}
|
||||
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
||||
|
||||
size_t actual_force_start = SIZE_MAX;
|
||||
size_t actual_force_end = SIZE_MAX;
|
||||
size_t actual_force_end = SIZE_MAX;
|
||||
|
||||
// Feed the sequence and track when forcing occurs
|
||||
for (size_t i = 0; i < sequence.size(); i++) {
|
||||
// Check if we're in forcing state by applying and seeing if logits are modified
|
||||
cur_p.selected = -1;
|
||||
for (size_t j = 0; j < cur.size(); j++) {
|
||||
cur[j].logit = logf((float)(j+1)); // reset logits
|
||||
}
|
||||
for (size_t j = 0; j < cur.size(); j++) { cur[j].logit = logf((float)(j+1)); }
|
||||
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
|
||||
// Check if forcing is active (all logits except one should be -INFINITY)
|
||||
size_t finite_count = 0;
|
||||
llama_token finite_token = -1;
|
||||
for (size_t j = 0; j < cur.size(); j++) {
|
||||
if (std::isfinite(cur[j].logit)) {
|
||||
finite_count++;
|
||||
finite_token = cur[j].id;
|
||||
}
|
||||
if (std::isfinite(cur[j].logit)) { finite_count++; finite_token = cur[j].id; }
|
||||
}
|
||||
|
||||
llama_sampler_accept(sampler, sequence[i]);
|
||||
|
||||
fprintf(stderr, " i=%zu: token=%d, finite_count=%zu, finite_token=%d\n", i, (int)sequence[i], finite_count, (int)finite_token);
|
||||
fprintf(stderr, " i=%zu: token=%d, finite_count=%zu, finite_token=%d\n",
|
||||
i, (int)sequence[i], finite_count, (int)finite_token);
|
||||
|
||||
if (finite_count == 1) {
|
||||
if (actual_force_start == SIZE_MAX) {
|
||||
actual_force_start = i;
|
||||
}
|
||||
if (actual_force_start == SIZE_MAX) { actual_force_start = i; }
|
||||
actual_force_end = i;
|
||||
} else if (actual_force_start != SIZE_MAX && actual_force_end != SIZE_MAX) {
|
||||
// Forcing stopped
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
llama_sampler_free(sampler);
|
||||
|
||||
// Verify forcing occurred at expected positions
|
||||
if (expected_force_start == SIZE_MAX) {
|
||||
if (actual_force_start != SIZE_MAX) {
|
||||
fprintf(stderr, "Test '%s' FAILED: Expected no forcing, but forcing occurred at %zu\n", test_name, actual_force_start);
|
||||
|
|
@ -112,126 +92,110 @@ static void test_reasoning_budget(
|
|||
GGML_ASSERT(false && "Forcing started at wrong position");
|
||||
}
|
||||
}
|
||||
|
||||
if (expected_force_end != SIZE_MAX) {
|
||||
if (actual_force_end < expected_force_end) {
|
||||
fprintf(stderr, "Test '%s' FAILED: Forcing ended at %zu, expected >= %zu\n", test_name, actual_force_end, expected_force_end);
|
||||
GGML_ASSERT(false && "Forcing ended too early");
|
||||
}
|
||||
if (expected_force_end != SIZE_MAX && actual_force_end < expected_force_end) {
|
||||
fprintf(stderr, "Test '%s' FAILED: Forcing ended at %zu, expected >= %zu\n", test_name, actual_force_end, expected_force_end);
|
||||
GGML_ASSERT(false && "Forcing ended too early");
|
||||
}
|
||||
|
||||
fprintf(stderr, " Test '%s' passed (force_start=%zu, force_end=%zu)\n", test_name, actual_force_start, actual_force_end);
|
||||
(void)sequence;
|
||||
}
|
||||
|
||||
// UTF-8 boundary detection unit test
|
||||
// Tests common_utf8_is_complete() from reasoning-budget.h
|
||||
static void test_utf8_boundary_detection() {
|
||||
// Complete sequences
|
||||
GGML_ASSERT(common_utf8_is_complete("hello"));
|
||||
GGML_ASSERT(common_utf8_is_complete(""));
|
||||
GGML_ASSERT(common_utf8_is_complete("\xC2\xA0")); // complete 2-byte UTF-8 (U+00A0)
|
||||
GGML_ASSERT(common_utf8_is_complete("\xE2\x80\x9C")); // complete 3-byte UTF-8 (left double quote)
|
||||
GGML_ASSERT(common_utf8_is_complete("\xF0\x9F\x98\x80")); // complete 4-byte UTF-8 (emoji)
|
||||
GGML_ASSERT(common_utf8_is_complete("abc\xC3\xA9")); // ASCII + complete 2-byte
|
||||
|
||||
// Incomplete sequences
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xC2", 1))); // 2-byte start, missing continuation
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2\x80", 2))); // 3-byte start + 1 cont, missing 1
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2", 1))); // 3-byte start, missing 2
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F\x98", 3))); // 4-byte start + 2 cont, missing 1
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F", 2))); // 4-byte start + 1 cont, missing 2
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0", 1))); // 4-byte start, missing 3
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\x80", 1))); // orphan continuation byte
|
||||
|
||||
// Mixed: ASCII followed by start of multi-byte
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("hello\xC3", 6))); // ASCII + incomplete 2-byte
|
||||
GGML_ASSERT(common_utf8_is_complete(std::string("hello\xC3\xA9", 7))); // ASCII + complete 2-byte
|
||||
GGML_ASSERT(common_utf8_is_complete("\xC2\xA0"));
|
||||
GGML_ASSERT(common_utf8_is_complete("\xE2\x80\x9C"));
|
||||
GGML_ASSERT(common_utf8_is_complete("\xF0\x9F\x98\x80"));
|
||||
GGML_ASSERT(common_utf8_is_complete("abc\xC3\xA9"));
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xC2", 1)));
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2\x80", 2)));
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2", 1)));
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F\x98", 3)));
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F", 2)));
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0", 1)));
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("\x80", 1)));
|
||||
GGML_ASSERT(!common_utf8_is_complete(std::string("hello\xC3", 6)));
|
||||
GGML_ASSERT(common_utf8_is_complete(std::string("hello\xC3\xA9", 7)));
|
||||
}
|
||||
|
||||
int main(void) {
|
||||
// Reasoning budget sampler tests
|
||||
printf("Testing reasoning budget sampler... ");
|
||||
|
||||
// Test 1: Basic budget with start/end tokens - no forcing (natural end before budget exhausted)
|
||||
// Test 1: Natural end before budget exhausted
|
||||
{
|
||||
const std::vector<llama_token> start = {100}; // start token
|
||||
const std::vector<llama_token> end = {101}; // end token
|
||||
const std::vector<llama_token> forced = {102}; // forced token (not used in this test)
|
||||
const std::vector<llama_token> sequence = {100, 50, 51, 101, 52}; // start, two tokens, end, one more
|
||||
|
||||
test_reasoning_budget("natural end before budget exhausted", sequence, start, end, forced,
|
||||
5, // budget of 5 tokens
|
||||
REASONING_BUDGET_IDLE,
|
||||
SIZE_MAX, SIZE_MAX); // no forcing expected (natural end)
|
||||
std::vector<llama_token> start = {100}, end = {101}, forced = {102}, msg = {};
|
||||
std::vector<llama_token> seq = {100, 50, 51, 101, 52};
|
||||
test_reasoning_budget("natural end before budget exhausted", seq, start, end, forced, msg, 5, 0, REASONING_BUDGET_IDLE, SIZE_MAX, SIZE_MAX);
|
||||
}
|
||||
|
||||
// Test 2: Budget exhausted, forcing should occur
|
||||
// Flow: i=0 apply()->passthrough, accept(100)->COUNTING; i=1 accept(50)->remaining=1
|
||||
// i=2 accept(51)->remaining=0->FORCING; i=3 apply() forces token[0]; i=4 apply() forces token[1]
|
||||
// At i=4, accept() advances force_pos to 2 which equals forced_tokens.size(), so state becomes DONE
|
||||
// Test 2: Budget exhausted, forcing occurs
|
||||
{
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
const std::vector<llama_token> forced = {102, 101}; // forced message + end
|
||||
const std::vector<llama_token> sequence = {100, 50, 51, 52, 53}; // start + 4 tokens (budget=2)
|
||||
|
||||
test_reasoning_budget("budget exhausted forcing", sequence, start, end, forced,
|
||||
2, // budget of 2 tokens
|
||||
REASONING_BUDGET_IDLE,
|
||||
3, // forcing starts at i=3 (accept at i=2 depletes budget, apply at i=3 forces)
|
||||
4); // forcing continues through i=4 (accept at i=4 transitions to DONE)
|
||||
std::vector<llama_token> start = {100}, end = {101}, forced = {102, 101}, msg = {};
|
||||
std::vector<llama_token> seq = {100, 50, 51, 52, 53};
|
||||
test_reasoning_budget("budget exhausted forcing", seq, start, end, forced, msg, 2, 0, REASONING_BUDGET_IDLE, 3, 4);
|
||||
}
|
||||
|
||||
// Test 3: Activate immediately with budget=0, forcing should start right away
|
||||
// Flow: init promotes COUNTING+budget=0 to FORCING, so apply() sees FORCING at i=0
|
||||
// Test 3: Budget=0 forces immediately
|
||||
{
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
const std::vector<llama_token> forced = {102, 101};
|
||||
const std::vector<llama_token> sequence = {100, 50, 51, 52}; // start token first, then 3 tokens
|
||||
|
||||
test_reasoning_budget("activate immediately budget=0", sequence, start, end, forced,
|
||||
0, // budget of 0 tokens
|
||||
REASONING_BUDGET_COUNTING, // starts counting, promoted to FORCING since budget=0
|
||||
0, // forcing starts at i=0 (initialized in FORCING, apply forces immediately)
|
||||
1); // forcing continues through i=1 (accept at i=1 transitions to DONE)
|
||||
std::vector<llama_token> start = {100}, end = {101}, forced = {102, 101}, msg = {};
|
||||
std::vector<llama_token> seq = {100, 50, 51, 52};
|
||||
test_reasoning_budget("activate immediately budget=0", seq, start, end, forced, msg, 0, 0, REASONING_BUDGET_COUNTING, 0, 1);
|
||||
}
|
||||
|
||||
// Test 4: No start/end tokens configured - passthrough (no forcing)
|
||||
// Test 4: No start/end — passthrough
|
||||
{
|
||||
const std::vector<llama_token> start = {};
|
||||
const std::vector<llama_token> end = {};
|
||||
const std::vector<llama_token> forced = {102};
|
||||
const std::vector<llama_token> sequence = {50, 51, 52, 53};
|
||||
|
||||
test_reasoning_budget("no start/end configured", sequence, start, end, forced,
|
||||
2, // budget
|
||||
REASONING_BUDGET_IDLE,
|
||||
SIZE_MAX, SIZE_MAX); // no forcing (no start/end configured)
|
||||
std::vector<llama_token> start = {}, end = {}, forced = {102}, msg = {};
|
||||
std::vector<llama_token> seq = {50, 51, 52, 53};
|
||||
test_reasoning_budget("no start/end configured", seq, start, end, forced, msg, 2, 0, REASONING_BUDGET_IDLE, SIZE_MAX, SIZE_MAX);
|
||||
}
|
||||
|
||||
// Test 5: Activate immediately with budget > 0, count down then force
|
||||
// Flow: i=0 accept(50)->remaining=1, i=1 accept(51)->remaining=0->FORCING
|
||||
// Forcing starts at i=2 (apply sees FORCING after accept at i=1 transitioned)
|
||||
// Test 5: Start in COUNTING state, count down then force
|
||||
{
|
||||
const std::vector<llama_token> start = {100};
|
||||
const std::vector<llama_token> end = {101};
|
||||
const std::vector<llama_token> forced = {102, 101};
|
||||
const std::vector<llama_token> sequence = {50, 51, 52, 53};
|
||||
|
||||
test_reasoning_budget("activate immediately with budget", sequence, start, end, forced,
|
||||
2, // budget of 2 tokens
|
||||
REASONING_BUDGET_COUNTING,
|
||||
2, // forcing starts at i=2 (after 2 accepts deplete budget, apply at i=2 forces)
|
||||
3); // forcing continues through i=3
|
||||
std::vector<llama_token> start = {100}, end = {101}, forced = {102, 101}, msg = {};
|
||||
std::vector<llama_token> seq = {50, 51, 52, 53};
|
||||
test_reasoning_budget("activate immediately with budget", seq, start, end, forced, msg, 2, 0, REASONING_BUDGET_COUNTING, 2, 3);
|
||||
}
|
||||
|
||||
printf("OK (5 tests passed)\n");
|
||||
// Test 6: Two-phase — model concludes naturally in conclusion window
|
||||
{
|
||||
std::vector<llama_token> start = {100}, end = {101}, forced = {101}, msg = {200};
|
||||
std::vector<llama_token> seq = {100, 50, 51, 200, 101, 52};
|
||||
test_reasoning_budget("two-phase natural end in conclusion window", seq, start, end, forced, msg, 2, 3, REASONING_BUDGET_IDLE, 3, 3);
|
||||
}
|
||||
|
||||
// Test 7: Two-phase — conclusion budget exhausted, safety net fires
|
||||
{
|
||||
std::vector<llama_token> start = {100}, end = {101}, forced = {101}, msg = {200};
|
||||
std::vector<llama_token> seq = {100, 50, 51, 200, 52, 101};
|
||||
test_reasoning_budget("two-phase conclusion budget exhausted safety net fires", seq, start, end, forced, msg, 2, 1, REASONING_BUDGET_IDLE, 3, 5);
|
||||
}
|
||||
|
||||
// Test 8: Two-phase — no message tokens, conclusion only (skips INJECTING)
|
||||
{
|
||||
std::vector<llama_token> start = {100}, end = {101}, forced = {101}, msg = {};
|
||||
std::vector<llama_token> seq = {100, 50, 51, 101, 52};
|
||||
test_reasoning_budget("two-phase no message tokens conclusion only", seq, start, end, forced, msg, 2, 5, REASONING_BUDGET_IDLE, SIZE_MAX, SIZE_MAX);
|
||||
}
|
||||
|
||||
// Test 9: Backward compat — conclusion_budget=0 identical to original
|
||||
{
|
||||
std::vector<llama_token> start = {100}, end = {101}, forced = {102, 101}, msg = {};
|
||||
std::vector<llama_token> seq = {100, 50, 51, 52, 53};
|
||||
test_reasoning_budget("backward compat conclusion_budget=0", seq, start, end, forced, msg, 2, 0, REASONING_BUDGET_IDLE, 3, 4);
|
||||
}
|
||||
|
||||
// Test 10: Two-phase — multi-token message (3 tokens all forced before CONCLUDING)
|
||||
{
|
||||
std::vector<llama_token> start = {100}, end = {101}, forced = {101}, msg = {200, 201, 202};
|
||||
std::vector<llama_token> seq = {100, 50, 51, 200, 201, 202, 101, 52};
|
||||
test_reasoning_budget("two-phase multi-token message injection", seq, start, end, forced, msg, 2, 5, REASONING_BUDGET_IDLE, 3, 5);
|
||||
}
|
||||
|
||||
printf("OK (10 tests passed)\n");
|
||||
|
||||
printf("Testing UTF-8 boundary detection... ");
|
||||
test_utf8_boundary_detection();
|
||||
printf("OK\n");
|
||||
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue