#include "reasoning-budget.h"
#include "common.h"
#include "unicode.h"
#include "log.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <string>
#include <vector>

// Incremental matcher for a fixed token sequence.
//
// Tokens are fed one at a time through advance(); `pos` is the number of
// leading tokens of `tokens` that currently match the tail of the stream.
struct token_matcher {
    std::vector<llama_token> tokens;
    size_t pos = 0;

    // Feed one token. Returns true exactly when this token completes the
    // sequence, after which the matcher restarts from the beginning.
    // An empty sequence never matches.
    bool advance(llama_token token) {
        if (tokens.empty()) {
            return false;
        }
        if (pos >= tokens.size()) {
            // defensive: keep the index arithmetic below in bounds
            pos = 0;
        }

        // The stream currently ends with tokens[0..pos) followed by `token`.
        // Pick the longest prefix of `tokens` that is a suffix of that tail.
        // len == pos + 1 is the plain "next token matches" case; shorter
        // lengths handle sequences with a repeated prefix (e.g. "A A B" fed
        // "A A A B"), which a simple restart-from-zero would miss.
        const llama_token * seq = tokens.data();
        for (size_t len = pos + 1; len > 0; --len) {
            if (seq[len - 1] != token) {
                continue;
            }
            // the first len-1 pattern tokens must equal the last len-1 matched tokens
            if (!std::equal(seq, seq + (len - 1), seq + (pos - (len - 1)))) {
                continue;
            }
            pos = len;
            if (pos >= tokens.size()) {
                pos = 0;
                return true;
            }
            return false;
        }

        pos = 0;
        return false;
    }

    // Forget any partial match.
    void reset() {
        pos = 0;
    }
};

// Per-sampler state for the reasoning-budget sampler.
struct common_reasoning_budget_ctx {
    const llama_vocab * vocab;

    token_matcher start_matcher;            // detects the sequence that activates counting
    token_matcher end_matcher;              // detects the sequence that ends the block naturally
    std::vector<llama_token> forced_tokens; // tokens forced once the budget is exhausted

    int32_t budget;    // maximum tokens in reasoning block
    int32_t remaining; // tokens remaining in budget

    common_reasoning_budget_state state;

    // for forcing
    size_t force_pos; // next position in forced_tokens to force
};

static const char * common_reasoning_budget_name(const struct llama_sampler * /*smpl*/) {
    return "reasoning-budget";
}

// Advance the state machine with a token that has been accepted.
//
//   IDLE                  -> COUNTING (or FORCING when the budget is <= 0) on the start sequence
//   COUNTING/WAITING_UTF8 -> DONE on the end sequence
//   COUNTING              -> FORCING when the budget runs out on a complete UTF-8 piece,
//                            otherwise WAITING_UTF8
//   WAITING_UTF8          -> FORCING on the first token whose piece is complete UTF-8
//   FORCING / DONE        -> unchanged here
static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_token token) {
    auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;

    switch (ctx->state) {
        case REASONING_BUDGET_IDLE:
            {
                if (ctx->start_matcher.advance(token)) {
                    ctx->state     = REASONING_BUDGET_COUNTING;
                    ctx->remaining = ctx->budget;
                    LOG_INF("reasoning-budget: activated, budget=%d tokens\n", ctx->budget);

                    if (ctx->remaining <= 0) {
                        ctx->state     = REASONING_BUDGET_FORCING;
                        ctx->force_pos = 0;
                        LOG_INF("reasoning-budget: budget=0, forcing immediately\n");
                    }
                }
                break;
            }
        case REASONING_BUDGET_COUNTING:
        case REASONING_BUDGET_WAITING_UTF8:
            {
                // a natural end always wins, even on the token that exhausts the budget
                if (ctx->end_matcher.advance(token)) {
                    ctx->state = REASONING_BUDGET_DONE;
                    LOG_INF("reasoning-budget: deactivated (natural end)\n");
                    break;
                }

                // without a vocab the piece cannot be inspected, so it is treated as complete
                bool utf8_complete = true;
                if (ctx->vocab != nullptr) {
                    const std::string piece = common_token_to_piece(ctx->vocab, token, false);
                    utf8_complete = common_utf8_is_complete(piece);
                }

                if (ctx->state == REASONING_BUDGET_WAITING_UTF8) {
                    if (utf8_complete) {
                        ctx->state     = REASONING_BUDGET_FORCING;
                        ctx->force_pos = 0;
                        ctx->end_matcher.reset();
                        LOG_INF("reasoning-budget: UTF-8 complete, now forcing end sequence\n");
                    }
                } else if (ctx->state == REASONING_BUDGET_COUNTING) {
                    ctx->remaining--;
                    if (ctx->remaining <= 0) {
                        if (utf8_complete) {
                            ctx->state     = REASONING_BUDGET_FORCING;
                            ctx->force_pos = 0;
                            ctx->end_matcher.reset();
                            LOG_INF("reasoning-budget: budget exhausted, forcing end sequence\n");
                        } else {
                            ctx->state = REASONING_BUDGET_WAITING_UTF8;
                            ctx->end_matcher.reset();
                            LOG_INF("reasoning-budget: budget exhausted, waiting for UTF-8 completion\n");
                        }
                    }
                }
                break;
            }
        case REASONING_BUDGET_FORCING:
            // force_pos is advanced in apply(), not here.
            // This ensures the first forced token isn't skipped when the sampler
            // is initialized directly in FORCING state (e.g. COUNTING + budget=0)
            break;
        case REASONING_BUDGET_DONE:
            break;
    }
}

// While in FORCING state, set the logit of every candidate except the next
// forced token to -inf and step to the following forced token. In any other
// state the candidates are left untouched.
static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
    auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;

    if (ctx->state != REASONING_BUDGET_FORCING) {
        // passthrough - don't modify logits
        return;
    }

    if (ctx->force_pos >= ctx->forced_tokens.size()) {
        // nothing (left) to force
        return;
    }

    const llama_token forced = ctx->forced_tokens[ctx->force_pos];

    // set all logits to -inf except the forced token
    for (size_t i = 0; i < cur_p->size; i++) {
        if (cur_p->data[i].id != forced) {
            cur_p->data[i].logit = -INFINITY;
        }
    }

    // advance to next forced token (done here rather than in accept so that
    // the first forced token isn't skipped when starting in FORCING state)
    ctx->force_pos++;
    if (ctx->force_pos >= ctx->forced_tokens.size()) {
        ctx->state = REASONING_BUDGET_DONE;
        LOG_INF("reasoning-budget: forced sequence complete, done\n");
    }
}

// Return the sampler to IDLE with a full budget and no partial matches.
// NOTE(review): this always resets to IDLE, even when the sampler was created
// with a different initial_state - confirm that this is intended.
static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
    auto * ctx = (common_reasoning_budget_ctx *) smpl->ctx;

    ctx->state     = REASONING_BUDGET_IDLE;
    ctx->remaining = ctx->budget;
    ctx->start_matcher.reset();
    ctx->end_matcher.reset();
    ctx->force_pos = 0;
}

// Create an independent copy of the sampler, including its runtime progress
// (state, remaining budget, partial matches and position in the forced
// sequence), so that the copy continues exactly where the source is.
static struct llama_sampler * common_reasoning_budget_clone(const struct llama_sampler * smpl) {
    const auto * ctx = (const common_reasoning_budget_ctx *) smpl->ctx;

    struct llama_sampler * result = common_reasoning_budget_init(
        ctx->vocab,
        ctx->start_matcher.tokens,
        ctx->end_matcher.tokens,
        ctx->forced_tokens,
        ctx->budget,
        ctx->state);

    // init() only receives the configuration and the state; copy the rest of
    // the runtime fields so that the clone does not restart its budget or the
    // forced sequence
    auto * result_ctx = (common_reasoning_budget_ctx *) result->ctx;

    result_ctx->state             = ctx->state;
    result_ctx->remaining         = ctx->remaining;
    result_ctx->force_pos         = ctx->force_pos;
    result_ctx->start_matcher.pos = ctx->start_matcher.pos;
    result_ctx->end_matcher.pos   = ctx->end_matcher.pos;

    return result;
}

// Release the context allocated in common_reasoning_budget_init().
static void common_reasoning_budget_free(struct llama_sampler * smpl) {
    delete (common_reasoning_budget_ctx *) smpl->ctx;
}

static struct llama_sampler_i common_reasoning_budget_i = {
    /* .name              = */ common_reasoning_budget_name,
    /* .accept            = */ common_reasoning_budget_accept,
    /* .apply             = */ common_reasoning_budget_apply,
    /* .reset             = */ common_reasoning_budget_reset,
    /* .clone             = */ common_reasoning_budget_clone,
    /* .free              = */ common_reasoning_budget_free,
    /* .backend_init      = */ nullptr,
    /* .backend_accept    = */ nullptr,
    /* .backend_apply     = */ nullptr,
    /* .backend_set_input = */ nullptr,
};

// Create a reasoning-budget sampler.
//
//   vocab         - used to check whether a token's piece is complete UTF-8;
//                   may be nullptr, in which case every piece counts as complete
//   start_tokens  - sequence that activates budget counting
//   end_tokens    - sequence that ends the reasoning block naturally
//   forced_tokens - sequence forced once the budget is exhausted
//   budget        - maximum number of tokens in the reasoning block
//   initial_state - state to start in; COUNTING with budget <= 0 starts in FORCING
struct llama_sampler * common_reasoning_budget_init(
        const struct llama_vocab       * vocab,
        const std::vector<llama_token> & start_tokens,
        const std::vector<llama_token> & end_tokens,
        const std::vector<llama_token> & forced_tokens,
        int32_t                          budget,
        common_reasoning_budget_state    initial_state) {
    // promote COUNTING with budget <= 0 to FORCING
    if (initial_state == REASONING_BUDGET_COUNTING && budget <= 0) {
        initial_state = REASONING_BUDGET_FORCING;
    }

    return llama_sampler_init(
        /* .iface = */ &common_reasoning_budget_i,
        /* .ctx   = */ new common_reasoning_budget_ctx {
            /* .vocab         = */ vocab,
            /* .start_matcher = */ { start_tokens, 0 },
            /* .end_matcher   = */ { end_tokens, 0 },
            /* .forced_tokens = */ forced_tokens,
            /* .budget        = */ budget,
            /* .remaining     = */ budget,
            /* .state         = */ initial_state,
            /* .force_pos     = */ 0,
        }
    );
}