sampling : add rng test case

This commit is contained in:
Jan Boon 2026-02-09 15:13:56 +00:00
parent 10179a636d
commit 1f42650078
1 changed files with 74 additions and 0 deletions

View File

@ -192,6 +192,73 @@ static void test_top_n_sigma(const std::vector<float> & probs, const std::vector
tester.check();
}
static void test_dist_rng(uint32_t seed, bool blue_noise, enum llama_rng_type rng_type,
const std::vector<llama_token> & expected, const char * desc) {
const int n_vocab = 16;
const int n_samples = 32;
// fixed non-uniform distribution: token i has logit log(i+1)
std::vector<llama_token_data> data(n_vocab);
for (int i = 0; i < n_vocab; i++) {
data[i] = {i, logf((float)(i + 1)), 0.0f};
}
auto * sampler = llama_sampler_init_dist_rng(seed, blue_noise, rng_type);
std::vector<llama_token> tokens(n_samples);
for (int i = 0; i < n_samples; i++) {
std::vector<llama_token_data> cur(data);
llama_token_data_array cur_p = {cur.data(), cur.size(), -1, false};
llama_sampler_apply(sampler, &cur_p);
GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (llama_token)n_vocab);
tokens[i] = cur_p.data[cur_p.selected].id;
}
if (expected.empty()) {
// print sequence for capture
printf("test_dist_rng %s: {", desc);
for (int i = 0; i < n_samples; i++) {
printf("%s%d", i ? ", " : "", tokens[i]);
}
printf("}\n");
} else {
// verify against known sequence
GGML_ASSERT((int)expected.size() == n_samples);
bool match = true;
for (int i = 0; i < n_samples; i++) {
if (tokens[i] != expected[i]) {
match = false;
break;
}
}
if (!match) {
printf("test_dist_rng %s: MISMATCH\n got: {", desc);
for (int i = 0; i < n_samples; i++) {
printf("%s%d", i ? ", " : "", tokens[i]);
}
printf("}\n expected: {");
for (int i = 0; i < n_samples; i++) {
printf("%s%d", i ? ", " : "", expected[i]);
}
printf("}\n");
GGML_ASSERT(false);
}
// also verify reset reproduces same sequence
llama_sampler_reset(sampler);
for (int i = 0; i < n_samples; i++) {
std::vector<llama_token_data> cur(data);
llama_token_data_array cur_p = {cur.data(), cur.size(), -1, false};
llama_sampler_apply(sampler, &cur_p);
GGML_ASSERT(cur_p.data[cur_p.selected].id == tokens[i]);
}
printf("test_dist_rng %-30s OK\n", desc);
}
llama_sampler_free(sampler);
}
static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
) {
sampler_tester tester(n_vocab);
@ -392,6 +459,13 @@ int main(void) {
test_sampler_queue(10000, "mkp", 100, 0.8f, 0.1f);
test_sampler_queue(10000, "mpk", 100, 0.8f, 0.1f);
test_dist_rng(42, false, LLAMA_RNG_TYPE_LOWBIAS32,
{5, 12, 8, 10, 12, 11, 10, 8, 8, 10, 11, 9, 7, 6, 11, 13, 14, 15, 13, 4, 12, 14, 13, 13, 14, 12, 5, 15, 4, 13, 15, 12},
"lowbias32");
test_dist_rng(42, true, LLAMA_RNG_TYPE_LOWBIAS32,
{10, 5, 12, 8, 15, 13, 3, 10, 13, 12, 2, 15, 8, 14, 5, 11, 7, 9, 15, 11, 8, 2, 12, 14, 7, 9, 13, 10, 14, 5, 12, 15},
"lowbias32 + blue noise");
printf("OK\n");
test_perf();