// llama.cpp/common/reasoning-budget.h

#pragma once
#include "llama.h"
#include <cstdint>
#include <vector>
// States of the reasoning budget sampler. Values are sequential starting at 0,
// in declaration order.
enum common_reasoning_budget_state {
    // waiting for start sequence
    REASONING_BUDGET_IDLE = 0,

    // counting down tokens
    REASONING_BUDGET_COUNTING = 1,

    // forcing message tokens before conclusion phase
    REASONING_BUDGET_INJECTING = 2,

    // conclusion phase: model free to conclude naturally
    REASONING_BUDGET_CONCLUDING = 3,

    // forcing end sequence (hard cutoff safety net)
    REASONING_BUDGET_FORCING = 4,

    // budget exhausted, waiting for UTF-8 completion
    REASONING_BUDGET_WAITING_UTF8 = 5,

    // passthrough forever
    REASONING_BUDGET_DONE = 6,
};
// Creates a reasoning budget sampler that limits token generation inside a
// reasoning block (e.g. between <think> and </think>).
//
// State machine (states are the common_reasoning_budget_state enumerators):
//   IDLE -> COUNTING -> INJECTING -> CONCLUDING -> WAITING_UTF8 -> FORCING -> DONE
// NOTE(review): INJECTING is placed before CONCLUDING based on the enum comments
// ("forcing message tokens before conclusion phase"); confirm the exact transition
// order, and which states are skipped when conclusion_budget is 0, against the
// implementation.
//
//   IDLE:         passthrough, watching for start_tokens sequence
//   COUNTING:     counting down remaining tokens, watching for natural end_tokens
//   INJECTING:    forcing message tokens before the conclusion phase
//   CONCLUDING:   conclusion phase after message injection; model generates freely
//                 until it produces end_tokens naturally or conclusion_budget runs out
//   WAITING_UTF8: budget exhausted, allowing tokens to complete a UTF-8 sequence
//   FORCING:      forces forced_tokens token-by-token (all other logits -> -inf)
//   DONE:         passthrough forever
//
// Parameters:
//   vocab             - vocabulary (used for UTF-8 boundary detection; can be nullptr)
//   start_tokens      - token sequence that activates counting
//   end_tokens        - token sequence for natural deactivation
//   forced_tokens     - token sequence forced when budget expires (hard-cutoff safety net)
//   message_tokens    - message token sequence forced before the conclusion phase
//                       (presumably only used when conclusion_budget > 0 - TODO confirm);
//                       required, has no default
//   budget            - max tokens allowed in the thinking phase
//   conclusion_budget - tokens reserved for conclusion phase (0 = disabled, original
//                       behavior); defaults to 0
//   prefill_tokens    - tokens already present in the prompt (generation prompt);
//                       used to determine the initial state: COUNTING if they begin
//                       with start_tokens (but don't also end with end_tokens),
//                       IDLE otherwise. COUNTING with budget <= 0 is promoted to FORCING.
//                       Defaults to an empty sequence.
//
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
const std::vector<llama_token> & message_tokens,
int32_t budget,
int32_t conclusion_budget = 0,
const std::vector<llama_token> & prefill_tokens = {});
// Variant that takes an explicit initial state (used by tests and clone) instead of
// deriving the initial state from prefill tokens.
// COUNTING with budget <= 0 is promoted to FORCING.
//
// Unlike the prefill_tokens overload, no parameter here has a default value, so
// conclusion_budget and initial_state must always be passed explicitly.
//
// Parameters:
//   vocab             - vocabulary (used for UTF-8 boundary detection)
//   start_tokens      - token sequence that activates counting
//   end_tokens        - token sequence for natural deactivation
//   forced_tokens     - token sequence forced when budget expires
//   message_tokens    - message token sequence forced before the conclusion phase
//   budget            - max tokens allowed in the thinking phase
//   conclusion_budget - tokens reserved for conclusion phase (0 = disabled)
//   initial_state     - state the sampler starts in
struct llama_sampler * common_reasoning_budget_init(
const struct llama_vocab * vocab,
const std::vector<llama_token> & start_tokens,
const std::vector<llama_token> & end_tokens,
const std::vector<llama_token> & forced_tokens,
const std::vector<llama_token> & message_tokens,
int32_t budget,
int32_t conclusion_budget,
common_reasoning_budget_state initial_state);
// Returns the common_reasoning_budget_state of the given sampler.
// NOTE(review): presumably smpl must be a sampler created by
// common_reasoning_budget_init; behavior for other samplers or nullptr is not
// visible from this header - confirm against the implementation.
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);