220 lines
7.3 KiB
C++
220 lines
7.3 KiB
C++
#include "reasoning-budget.h"
|
|
#include "common.h"
|
|
#include "unicode.h"
|
|
|
|
#include "log.h"
|
|
|
|
#include <cmath>
|
|
#include <cstdint>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
struct token_matcher {
|
|
std::vector<llama_token> tokens;
|
|
size_t pos = 0;
|
|
|
|
bool advance(llama_token token) {
|
|
if (tokens.empty()) {
|
|
return false;
|
|
}
|
|
|
|
if (token == tokens[pos]) {
|
|
pos++;
|
|
if (pos >= tokens.size()) {
|
|
pos = 0;
|
|
return true;
|
|
}
|
|
} else {
|
|
pos = 0;
|
|
if (token == tokens[0]) {
|
|
pos = 1;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
void reset() { pos = 0; }
|
|
};
|
|
|
|
struct common_reasoning_budget_ctx {
|
|
const llama_vocab * vocab;
|
|
|
|
token_matcher start_matcher;
|
|
token_matcher end_matcher;
|
|
std::vector<llama_token> forced_tokens;
|
|
|
|
int32_t budget; // maximum tokens in reasoning block
|
|
int32_t remaining; // tokens remaining in budget
|
|
|
|
common_reasoning_budget_state state;
|
|
|
|
// for forcing
|
|
size_t force_pos; // next position in forced_tokens to force
|
|
};
|
|
|
|
// Sampler name hook; the sampler instance itself is not consulted.
static const char * common_reasoning_budget_name(const struct llama_sampler * /* unused */) {
    static const char * const sampler_name = "reasoning-budget";
    return sampler_name;
}
|
|
|
|
static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) {
|
|
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
|
|
|
|
switch (ctx->state) {
|
|
case REASONING_BUDGET_IDLE:
|
|
{
|
|
if (ctx->start_matcher.advance(token)) {
|
|
ctx->state = REASONING_BUDGET_COUNTING;
|
|
ctx->remaining = ctx->budget;
|
|
LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);
|
|
|
|
if (ctx->remaining <= 0) {
|
|
ctx->state = REASONING_BUDGET_FORCING;
|
|
ctx->force_pos = 0;
|
|
LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
case REASONING_BUDGET_COUNTING:
|
|
case REASONING_BUDGET_WAITING_UTF8:
|
|
{
|
|
if (ctx->end_matcher.advance(token)) {
|
|
ctx->state = REASONING_BUDGET_DONE;
|
|
LOG_INF("reasoning-budget: deactivated (natural end)\n");
|
|
break;
|
|
}
|
|
|
|
bool utf8_complete = true;
|
|
if (ctx->vocab != nullptr) {
|
|
const std::string piece = common_token_to_piece(ctx->vocab, token, false);
|
|
utf8_complete = common_utf8_is_complete(piece);
|
|
}
|
|
|
|
if (ctx->state == REASONING_BUDGET_WAITING_UTF8) {
|
|
if (utf8_complete) {
|
|
ctx->state = REASONING_BUDGET_FORCING;
|
|
ctx->force_pos = 0;
|
|
ctx->end_matcher.reset();
|
|
LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
|
|
}
|
|
} else if (ctx->state == REASONING_BUDGET_COUNTING) {
|
|
ctx->remaining--;
|
|
if (ctx->remaining <= 0) {
|
|
if (utf8_complete) {
|
|
ctx->state = REASONING_BUDGET_FORCING;
|
|
ctx->force_pos = 0;
|
|
ctx->end_matcher.reset();
|
|
LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
|
|
} else {
|
|
ctx->state = REASONING_BUDGET_WAITING_UTF8;
|
|
ctx->end_matcher.reset();
|
|
LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
|
|
}
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
case REASONING_BUDGET_FORCING:
|
|
// force_pos is advanced in apply(), not here.
|
|
// This ensures the first forced token isn't skipped when the sampler
|
|
// is initialized directly in FORCING state (e.g. COUNTING + budget=0)
|
|
break;
|
|
case REASONING_BUDGET_DONE:
|
|
break;
|
|
}
|
|
}
|
|
|
|
static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
|
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
|
|
|
|
if (ctx->state != REASONING_BUDGET_FORCING) {
|
|
// passthrough — don't modify logits
|
|
return;
|
|
}
|
|
|
|
if (ctx->force_pos >= ctx->forced_tokens.size()) {
|
|
return;
|
|
}
|
|
|
|
const llama_token forced = ctx->forced_tokens[ctx->force_pos];
|
|
|
|
// set all logits to -inf except the forced token
|
|
for (size_t i = 0; i < cur_p->size; i++) {
|
|
if (cur_p->data[i].id != forced) {
|
|
cur_p->data[i].logit = -INFINITY;
|
|
}
|
|
}
|
|
|
|
// advance to next forced token (done here rather than in accept so that
|
|
// the first forced token isn't skipped when starting in FORCING state)
|
|
ctx->force_pos++;
|
|
if (ctx->force_pos >= ctx->forced_tokens.size()) {
|
|
ctx->state = REASONING_BUDGET_DONE;
|
|
LOG_INF("reasoning-budget: forced sequence complete, done\n");
|
|
}
|
|
}
|
|
|
|
static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
|
|
auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;
|
|
ctx->state = REASONING_BUDGET_IDLE;
|
|
ctx->remaining = ctx->budget;
|
|
ctx->start_matcher.reset();
|
|
ctx->end_matcher.reset();
|
|
ctx->force_pos = 0;
|
|
}
|
|
|
|
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
|
|
const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;
|
|
return common_reasoning_budget_init(
|
|
ctx->vocab,
|
|
ctx->start_matcher.tokens,
|
|
ctx->end_matcher.tokens,
|
|
ctx->forced_tokens,
|
|
ctx->budget,
|
|
ctx->state);
|
|
}
|
|
|
|
static void common_reasoning_budget_free(struct llama_sampler * smpl) {
|
|
delete (common_reasoning_budget_ctx *) smpl->ctx;
|
|
}
|
|
|
|
static struct llama_sampler_i common_reasoning_budget_i = {
|
|
/* .name = */ common_reasoning_budget_name,
|
|
/* .accept = */ common_reasoning_budget_accept,
|
|
/* .apply = */ common_reasoning_budget_apply,
|
|
/* .reset = */ common_reasoning_budget_reset,
|
|
/* .clone = */ common_reasoning_budget_clone,
|
|
/* .free = */ common_reasoning_budget_free,
|
|
/* .backend_init = */ nullptr,
|
|
/* .backend_accept = */ nullptr,
|
|
/* .backend_apply = */ nullptr,
|
|
/* .backend_set_input = */ nullptr,
|
|
};
|
|
|
|
struct llama_sampler * common_reasoning_budget_init(
|
|
const struct llama_vocab * vocab,
|
|
const std::vector<llama_token> & start_tokens,
|
|
const std::vector<llama_token> & end_tokens,
|
|
const std::vector<llama_token> & forced_tokens,
|
|
int32_t budget,
|
|
common_reasoning_budget_state initial_state) {
|
|
// promote COUNTING with budget <= 0 to FORCING
|
|
if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
|
|
initial_state = REASONING_BUDGET_FORCING;
|
|
}
|
|
|
|
return llama_sampler_init(
|
|
/* .iface = */ &common_reasoning_budget_i,
|
|
/* .ctx = */ new common_reasoning_budget_ctx {
|
|
/* .vocab = */ vocab,
|
|
/* .start_matcher = */ { start_tokens, 0 },
|
|
/* .end_matcher = */ { end_tokens, 0 },
|
|
/* .forced_tokens = */ forced_tokens,
|
|
/* .budget = */ budget,
|
|
/* .remaining = */ budget,
|
|
/* .state = */ initial_state,
|
|
/* .force_pos = */ 0,
|
|
}
|
|
);
|
|
}
|