239 lines
11 KiB
C++
239 lines
11 KiB
C++
#include "reasoning-budget.h"
|
|
#include "unicode.h"
|
|
|
|
#include "llama.h"
|
|
#include "ggml.h"
|
|
|
|
#ifdef NDEBUG
|
|
#undef NDEBUG
|
|
#endif
|
|
|
|
#include <cmath>
|
|
#include <cstddef>
|
|
#include <cstdio>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
// Reasoning budget sampler test helper
|
|
// These tests use nullptr vocab which safely falls back to treating all tokens as complete
|
|
// (The UTF-8 boundary detection logic is tested separately in test_utf8_boundary_detection)
|
|
static void test_reasoning_budget(
|
|
const char * test_name,
|
|
const std::vector<llama_token> & sequence,
|
|
const std::vector<llama_token> & start_tokens,
|
|
const std::vector<llama_token> & end_tokens,
|
|
const std::vector<llama_token> & forced_tokens,
|
|
int32_t budget,
|
|
common_reasoning_budget_state initial_state,
|
|
size_t expected_force_start, // token index where forcing should start (SIZE_MAX = never)
|
|
size_t expected_force_end // token index where forcing should end (after this, no more forcing)
|
|
) {
|
|
// Find the maximum token ID to ensure our vocab covers all tokens
|
|
llama_token max_token = 0;
|
|
for (auto t : sequence) max_token = std::max(max_token, t);
|
|
for (auto t : start_tokens) max_token = std::max(max_token, t);
|
|
for (auto t : end_tokens) max_token = std::max(max_token, t);
|
|
for (auto t : forced_tokens) max_token = std::max(max_token, t);
|
|
|
|
// Create a minimal sampler with mock vocabulary
|
|
// For this test, we use nullptr as vocab since we're testing state transitions
|
|
// The UTF-8 boundary check will treat all tokens as complete (safe fallback)
|
|
auto * sampler = common_reasoning_budget_init(
|
|
nullptr, // vocab - not used for basic state machine tests
|
|
start_tokens,
|
|
end_tokens,
|
|
forced_tokens,
|
|
budget,
|
|
initial_state
|
|
);
|
|
|
|
// Create a test token data array for checking forcing behavior
|
|
// Vocab size must be large enough to include all tokens (start, end, forced, sequence)
|
|
std::vector<llama_token_data> cur;
|
|
const size_t n_vocab = (size_t)max_token + 1;
|
|
for (size_t i = 0; i < n_vocab; i++) {
|
|
cur.emplace_back(llama_token_data{(llama_token)i, logf((float)(i+1)), 0.0f});
|
|
}
|
|
llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
|
|
|
|
size_t actual_force_start = SIZE_MAX;
|
|
size_t actual_force_end = SIZE_MAX;
|
|
|
|
// Feed the sequence and track when forcing occurs
|
|
for (size_t i = 0; i < sequence.size(); i++) {
|
|
llama_sampler_accept(sampler, sequence[i]);
|
|
|
|
// Check if we're in forcing state by applying and seeing if logits are modified
|
|
cur_p.selected = -1;
|
|
for (size_t j = 0; j < cur.size(); j++) {
|
|
cur[j].logit = logf((float)(j+1)); // reset logits
|
|
}
|
|
|
|
llama_sampler_apply(sampler, &cur_p);
|
|
|
|
// Check if forcing is active (all logits except one should be -INFINITY)
|
|
size_t finite_count = 0;
|
|
llama_token finite_token = -1;
|
|
for (size_t j = 0; j < cur.size(); j++) {
|
|
if (std::isfinite(cur[j].logit)) {
|
|
finite_count++;
|
|
finite_token = cur[j].id;
|
|
}
|
|
}
|
|
|
|
fprintf(stderr, " i=%zu: token=%d, finite_count=%zu, finite_token=%d\n", i, (int)sequence[i], finite_count, (int)finite_token);
|
|
|
|
if (finite_count == 1) {
|
|
if (actual_force_start == SIZE_MAX) {
|
|
actual_force_start = i;
|
|
}
|
|
actual_force_end = i;
|
|
} else if (actual_force_start != SIZE_MAX && actual_force_end != SIZE_MAX) {
|
|
// Forcing stopped
|
|
break;
|
|
}
|
|
}
|
|
|
|
llama_sampler_free(sampler);
|
|
|
|
// Verify forcing occurred at expected positions
|
|
if (expected_force_start == SIZE_MAX) {
|
|
if (actual_force_start != SIZE_MAX) {
|
|
fprintf(stderr, "Test '%s' FAILED: Expected no forcing, but forcing occurred at %zu\n", test_name, actual_force_start);
|
|
GGML_ASSERT(false && "Expected no forcing, but forcing occurred");
|
|
}
|
|
} else {
|
|
if (actual_force_start == SIZE_MAX) {
|
|
fprintf(stderr, "Test '%s' FAILED: Expected forcing but none occurred\n", test_name);
|
|
GGML_ASSERT(false && "Expected forcing but none occurred");
|
|
}
|
|
if (actual_force_start != expected_force_start) {
|
|
fprintf(stderr, "Test '%s' FAILED: Forcing started at %zu, expected %zu\n", test_name, actual_force_start, expected_force_start);
|
|
GGML_ASSERT(false && "Forcing started at wrong position");
|
|
}
|
|
}
|
|
|
|
if (expected_force_end != SIZE_MAX) {
|
|
if (actual_force_end < expected_force_end) {
|
|
fprintf(stderr, "Test '%s' FAILED: Forcing ended at %zu, expected >= %zu\n", test_name, actual_force_end, expected_force_end);
|
|
GGML_ASSERT(false && "Forcing ended too early");
|
|
}
|
|
}
|
|
|
|
fprintf(stderr, " Test '%s' passed (force_start=%zu, force_end=%zu)\n", test_name, actual_force_start, actual_force_end);
|
|
(void)sequence;
|
|
}
|
|
|
|
// UTF-8 boundary detection unit test
|
|
// Tests common_utf8_is_complete() from reasoning-budget.h
|
|
static void test_utf8_boundary_detection() {
|
|
// Complete sequences
|
|
GGML_ASSERT(common_utf8_is_complete("hello"));
|
|
GGML_ASSERT(common_utf8_is_complete(""));
|
|
GGML_ASSERT(common_utf8_is_complete("\xC2\xA0")); // complete 2-byte UTF-8 (U+00A0)
|
|
GGML_ASSERT(common_utf8_is_complete("\xE2\x80\x9C")); // complete 3-byte UTF-8 (left double quote)
|
|
GGML_ASSERT(common_utf8_is_complete("\xF0\x9F\x98\x80")); // complete 4-byte UTF-8 (emoji)
|
|
GGML_ASSERT(common_utf8_is_complete("abc\xC3\xA9")); // ASCII + complete 2-byte
|
|
|
|
// Incomplete sequences
|
|
GGML_ASSERT(!common_utf8_is_complete(std::string("\xC2", 1))); // 2-byte start, missing continuation
|
|
GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2\x80", 2))); // 3-byte start + 1 cont, missing 1
|
|
GGML_ASSERT(!common_utf8_is_complete(std::string("\xE2", 1))); // 3-byte start, missing 2
|
|
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F\x98", 3))); // 4-byte start + 2 cont, missing 1
|
|
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0\x9F", 2))); // 4-byte start + 1 cont, missing 2
|
|
GGML_ASSERT(!common_utf8_is_complete(std::string("\xF0", 1))); // 4-byte start, missing 3
|
|
GGML_ASSERT(!common_utf8_is_complete(std::string("\x80", 1))); // orphan continuation byte
|
|
|
|
// Mixed: ASCII followed by start of multi-byte
|
|
GGML_ASSERT(!common_utf8_is_complete(std::string("hello\xC3", 6))); // ASCII + incomplete 2-byte
|
|
GGML_ASSERT(common_utf8_is_complete(std::string("hello\xC3\xA9", 7))); // ASCII + complete 2-byte
|
|
}
|
|
|
|
int main(void) {
|
|
// Reasoning budget sampler tests
|
|
printf("Testing reasoning budget sampler... ");
|
|
|
|
// Test 1: Basic budget with start/end tokens - no forcing (natural end before budget exhausted)
|
|
{
|
|
const std::vector<llama_token> start = {100}; // start token
|
|
const std::vector<llama_token> end = {101}; // end token
|
|
const std::vector<llama_token> forced = {102}; // forced token (not used in this test)
|
|
const std::vector<llama_token> sequence = {100, 50, 51, 101, 52}; // start, two tokens, end, one more
|
|
|
|
test_reasoning_budget("natural end before budget exhausted", sequence, start, end, forced,
|
|
5, // budget of 5 tokens
|
|
REASONING_BUDGET_IDLE,
|
|
SIZE_MAX, SIZE_MAX); // no forcing expected (natural end)
|
|
}
|
|
|
|
// Test 2: Budget exhausted, forcing should occur
|
|
// Flow: i=0 accept(100)->COUNTING, i=1 accept(50)->remaining=1, i=2 accept(51)->remaining=0->FORCING
|
|
// Forcing is active at i=2 and i=3 (when apply() is called while in FORCING state)
|
|
// At i=4, force_pos becomes 2 which equals forced_tokens.size(), so state becomes DONE
|
|
{
|
|
const std::vector<llama_token> start = {100};
|
|
const std::vector<llama_token> end = {101};
|
|
const std::vector<llama_token> forced = {102, 101}; // forced message + end
|
|
const std::vector<llama_token> sequence = {100, 50, 51, 52, 53}; // start + 4 tokens (budget=2)
|
|
|
|
test_reasoning_budget("budget exhausted forcing", sequence, start, end, forced,
|
|
2, // budget of 2 tokens
|
|
REASONING_BUDGET_IDLE,
|
|
2, // forcing starts at i=2 (after accept(51) depletes budget, apply() forces)
|
|
3); // forcing continues through i=3 (at i=4 state becomes DONE)
|
|
}
|
|
|
|
// Test 3: Activate immediately with budget=0, forcing should start right away
|
|
// Flow: Since no start token in sequence, state stays IDLE (no start/end configured means passthrough)
|
|
// This test needs start token to be in the sequence or use activate_immediately with start token present
|
|
{
|
|
const std::vector<llama_token> start = {100};
|
|
const std::vector<llama_token> end = {101};
|
|
const std::vector<llama_token> forced = {102, 101};
|
|
const std::vector<llama_token> sequence = {100, 50, 51, 52}; // start token first, then 3 tokens
|
|
|
|
test_reasoning_budget("activate immediately budget=0", sequence, start, end, forced,
|
|
0, // budget of 0 tokens
|
|
REASONING_BUDGET_COUNTING, // starts counting, promoted to FORCING since budget=0
|
|
0, // forcing starts at i=0 (after accept(100), budget=0 goes straight to FORCING)
|
|
1); // forcing continues through i=1 (at i=2 state becomes DONE)
|
|
}
|
|
|
|
// Test 4: No start/end tokens configured - passthrough (no forcing)
|
|
{
|
|
const std::vector<llama_token> start = {};
|
|
const std::vector<llama_token> end = {};
|
|
const std::vector<llama_token> forced = {102};
|
|
const std::vector<llama_token> sequence = {50, 51, 52, 53};
|
|
|
|
test_reasoning_budget("no start/end configured", sequence, start, end, forced,
|
|
2, // budget
|
|
REASONING_BUDGET_IDLE,
|
|
SIZE_MAX, SIZE_MAX); // no forcing (no start/end configured)
|
|
}
|
|
|
|
// Test 5: Activate immediately with budget > 0, count down then force
|
|
// Flow: i=0 accept(50)->remaining=1, i=1 accept(51)->remaining=0->FORCING
|
|
// So forcing starts at i=1 (apply after accept sees FORCING with force_pos=0)
|
|
{
|
|
const std::vector<llama_token> start = {100};
|
|
const std::vector<llama_token> end = {101};
|
|
const std::vector<llama_token> forced = {102, 101};
|
|
const std::vector<llama_token> sequence = {50, 51, 52, 53};
|
|
|
|
test_reasoning_budget("activate immediately with budget", sequence, start, end, forced,
|
|
2, // budget of 2 tokens
|
|
REASONING_BUDGET_COUNTING,
|
|
1, // forcing starts at i=1 (after 2 accepts deplete budget)
|
|
2); // forcing continues through i=2
|
|
}
|
|
|
|
printf("OK (5 tests passed)\n");
|
|
|
|
printf("Testing UTF-8 boundary detection... ");
|
|
test_utf8_boundary_detection();
|
|
printf("OK\n");
|
|
|
|
return 0;
|
|
}
|