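// llama.cpp/common/reasoning-budget.cpp
//
// Reasoning-budget sampler: counts the tokens generated inside a reasoning block and,
// once the budget is spent, forces a fixed token sequence to end the block.
// (Listing metadata: C++, 220 lines, 7.3 KiB.)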

#include "reasoning-budget.h"
#include "common.h"
#include "unicode.h"
#include "log.h"
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <string>
#include <vector>
struct token_matcher {
std::vector<llama_token> tokens;
size_t pos = 0;
bool advance(llama_token token) {
if (tokens.empty()) {
return false;
}
if (token == tokens[pos]) {
pos++;
if (pos >= tokens.size()) {
pos = 0;
return true;
}
} else {
pos = 0;
if (token == tokens[0]) {
pos = 1;
}
}
return false;
}
void reset() { pos = 0; }
};
// Per-sampler state of the reasoning-budget sampler. It is stored as the sampler's
// ctx and recovered by casting smpl->ctx in the callbacks.
// NOTE: common_reasoning_budget_init() fills this struct with a positional aggregate
// initializer, so the member order here must stay in sync with it.
struct common_reasoning_budget_ctx {
const llama_vocab * vocab; // only used to detokenize for the UTF-8 completeness check; may be nullptr (check skipped)
token_matcher start_matcher; // detects the token sequence that opens a reasoning block (IDLE -> COUNTING)
token_matcher end_matcher; // detects the token sequence that closes a reasoning block naturally (-> DONE)
std::vector<llama_token> forced_tokens; // forced one per apply() call once the budget is exhausted
int32_t budget; // maximum tokens in reasoning block
int32_t remaining; // tokens remaining in budget
common_reasoning_budget_state state; // current phase: IDLE / COUNTING / WAITING_UTF8 / FORCING / DONE
// for forcing
size_t force_pos; // next position in forced_tokens to force
};
// Sampler-interface callback: the display name of this sampler.
// The sampler instance is not consulted.
static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) {
    static const char k_name[] = "reasoning-budget";
    return k_name;
}
static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) {
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
switch (ctx->state) {
case REASONING_BUDGET_IDLE:
{
if (ctx->start_matcher.advance(token)) {
ctx->state = REASONING_BUDGET_COUNTING;
ctx->remaining = ctx->budget;
LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);
if (ctx->remaining <= 0) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
}
}
break;
}
case REASONING_BUDGET_COUNTING:
case REASONING_BUDGET_WAITING_UTF8:
{
if (ctx->end_matcher.advance(token)) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: deactivated (natural end)\n");
break;
}
bool utf8_complete = true;
if (ctx->vocab != nullptr) {
const std::string piece = common_token_to_piece(ctx->vocab, token, false);
utf8_complete = common_utf8_is_complete(piece);
}
if (ctx->state == REASONING_BUDGET_WAITING_UTF8) {
if (utf8_complete) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
}
} else if (ctx->state == REASONING_BUDGET_COUNTING) {
ctx->remaining--;
if (ctx->remaining <= 0) {
if (utf8_complete) {
ctx->state = REASONING_BUDGET_FORCING;
ctx->force_pos = 0;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
} else {
ctx->state = REASONING_BUDGET_WAITING_UTF8;
ctx->end_matcher.reset();
LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
}
}
}
break;
}
case REASONING_BUDGET_FORCING:
// force_pos is advanced in apply(), not here.
// This ensures the first forced token isn't skipped when the sampler
// is initialized directly in FORCING state (e.g. COUNTING + budget=0)
break;
case REASONING_BUDGET_DONE:
break;
}
}
static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
if (ctx->state != REASONING_BUDGET_FORCING) {
// passthrough — don't modify logits
return;
}
if (ctx->force_pos >= ctx->forced_tokens.size()) {
return;
}
const llama_token forced = ctx->forced_tokens[ctx->force_pos];
// set all logits to -inf except the forced token
for (size_t i = 0; i < cur_p->size; i++) {
if (cur_p->data[i].id != forced) {
cur_p->data[i].logit = -INFINITY;
}
}
// advance to next forced token (done here rather than in accept so that
// the first forced token isn't skipped when starting in FORCING state)
ctx->force_pos++;
if (ctx->force_pos >= ctx->forced_tokens.size()) {
ctx->state = REASONING_BUDGET_DONE;
LOG_INF("reasoning-budget: forced sequence complete, done\n");
}
}
static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
ctx->state = REASONING_BUDGET_IDLE;
ctx->remaining = ctx->budget;
ctx->start_matcher.reset();
ctx->end_matcher.reset();
ctx->force_pos = 0;
}
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
return common_reasoning_budget_init(
ctx->vocab,
ctx->start_matcher.tokens,
ctx->end_matcher.tokens,
ctx->forced_tokens,
ctx->budget,
ctx->state);
}
static void common_reasoning_budget_free(struct llama_sampler * smpl) {
delete (common_reasoning_budget_ctx *) smpl->ctx;
}
// Callback table that binds the reasoning-budget sampler to the generic llama_sampler
// interface. Entries are positional, in the field order of llama_sampler_i.
// The backend_* hooks are not provided (nullptr).
static struct llama_sampler_i common_reasoning_budget_i = {
/* .name = */ common_reasoning_budget_name,
/* .accept = */ common_reasoning_budget_accept,
/* .apply = */ common_reasoning_budget_apply,
/* .reset = */ common_reasoning_budget_reset,
/* .clone = */ common_reasoning_budget_clone,
/* .free = */ common_reasoning_budget_free,
/* .backend_init = */ nullptr,
/* .backend_accept = */ nullptr,
/* .backend_apply = */ nullptr,
/* .backend_set_input = */ nullptr,
};
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
int32_t budget,
common_reasoning_budget_state initial_state) {
// promote COUNTING with budget <= 0 to FORCING
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
initial_state = REASONING_BUDGET_FORCING;
}
return llama_sampler_init(
/* .iface = */ &common_reasoning_budget_i,
/* .ctx = */ new common_reasoning_budget_ctx {
/* .vocab = */ vocab,
/* .start_matcher = */ { start_tokens, 0 },
/* .end_matcher = */ { end_tokens, 0 },
/* .forced_tokens = */ forced_tokens,
/* .budget = */ budget,
/* .remaining = */ budget,
/* .state = */ initial_state,
/* .force_pos = */ 0,
}
);
}