sampling : add rng test case
This commit is contained in:
parent
10179a636d
commit
1f42650078
|
|
@ -192,6 +192,73 @@ static void test_top_n_sigma(const std::vector<float> & probs, const std::vector
|
|||
tester.check();
|
||||
}
|
||||
|
||||
static void test_dist_rng(uint32_t seed, bool blue_noise, enum llama_rng_type rng_type,
|
||||
const std::vector<llama_token> & expected, const char * desc) {
|
||||
const int n_vocab = 16;
|
||||
const int n_samples = 32;
|
||||
|
||||
// fixed non-uniform distribution: token i has logit log(i+1)
|
||||
std::vector<llama_token_data> data(n_vocab);
|
||||
for (int i = 0; i < n_vocab; i++) {
|
||||
data[i] = {i, logf((float)(i + 1)), 0.0f};
|
||||
}
|
||||
|
||||
auto * sampler = llama_sampler_init_dist_rng(seed, blue_noise, rng_type);
|
||||
std::vector<llama_token> tokens(n_samples);
|
||||
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
std::vector<llama_token_data> cur(data);
|
||||
llama_token_data_array cur_p = {cur.data(), cur.size(), -1, false};
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (llama_token)n_vocab);
|
||||
tokens[i] = cur_p.data[cur_p.selected].id;
|
||||
}
|
||||
|
||||
if (expected.empty()) {
|
||||
// print sequence for capture
|
||||
printf("test_dist_rng %s: {", desc);
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
printf("%s%d", i ? ", " : "", tokens[i]);
|
||||
}
|
||||
printf("}\n");
|
||||
} else {
|
||||
// verify against known sequence
|
||||
GGML_ASSERT((int)expected.size() == n_samples);
|
||||
bool match = true;
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
if (tokens[i] != expected[i]) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!match) {
|
||||
printf("test_dist_rng %s: MISMATCH\n got: {", desc);
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
printf("%s%d", i ? ", " : "", tokens[i]);
|
||||
}
|
||||
printf("}\n expected: {");
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
printf("%s%d", i ? ", " : "", expected[i]);
|
||||
}
|
||||
printf("}\n");
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
|
||||
// also verify reset reproduces same sequence
|
||||
llama_sampler_reset(sampler);
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
std::vector<llama_token_data> cur(data);
|
||||
llama_token_data_array cur_p = {cur.data(), cur.size(), -1, false};
|
||||
llama_sampler_apply(sampler, &cur_p);
|
||||
GGML_ASSERT(cur_p.data[cur_p.selected].id == tokens[i]);
|
||||
}
|
||||
|
||||
printf("test_dist_rng %-30s OK\n", desc);
|
||||
}
|
||||
|
||||
llama_sampler_free(sampler);
|
||||
}
|
||||
|
||||
static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
|
||||
) {
|
||||
sampler_tester tester(n_vocab);
|
||||
|
|
@ -392,6 +459,13 @@ int main(void) {
|
|||
test_sampler_queue(10000, "mkp", 100, 0.8f, 0.1f);
|
||||
test_sampler_queue(10000, "mpk", 100, 0.8f, 0.1f);
|
||||
|
||||
test_dist_rng(42, false, LLAMA_RNG_TYPE_LOWBIAS32,
|
||||
{5, 12, 8, 10, 12, 11, 10, 8, 8, 10, 11, 9, 7, 6, 11, 13, 14, 15, 13, 4, 12, 14, 13, 13, 14, 12, 5, 15, 4, 13, 15, 12},
|
||||
"lowbias32");
|
||||
test_dist_rng(42, true, LLAMA_RNG_TYPE_LOWBIAS32,
|
||||
{10, 5, 12, 8, 15, 13, 3, 10, 13, 12, 2, 15, 8, 14, 5, 11, 7, 9, 15, 11, 8, 2, 12, 14, 7, 9, 13, 10, 14, 5, 12, 15},
|
||||
"lowbias32 + blue noise");
|
||||
|
||||
printf("OK\n");
|
||||
|
||||
test_perf();
|
||||
|
|
|
|||
Loading…
Reference in New Issue