diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 7cd96c5cd3..1a04ac5b11 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -192,6 +192,73 @@ static void test_top_n_sigma(const std::vector & probs, const std::vector tester.check(); } +static void test_dist_rng(uint32_t seed, bool blue_noise, enum llama_rng_type rng_type, + const std::vector & expected, const char * desc) { + const int n_vocab = 16; + const int n_samples = 32; + + // fixed non-uniform distribution: token i has logit log(i+1) + std::vector data(n_vocab); + for (int i = 0; i < n_vocab; i++) { + data[i] = {i, logf((float)(i + 1)), 0.0f}; + } + + auto * sampler = llama_sampler_init_dist_rng(seed, blue_noise, rng_type); + std::vector tokens(n_samples); + + for (int i = 0; i < n_samples; i++) { + std::vector cur(data); + llama_token_data_array cur_p = {cur.data(), cur.size(), -1, false}; + llama_sampler_apply(sampler, &cur_p); + GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (llama_token)n_vocab); + tokens[i] = cur_p.data[cur_p.selected].id; + } + + if (expected.empty()) { + // print sequence for capture + printf("test_dist_rng %s: {", desc); + for (int i = 0; i < n_samples; i++) { + printf("%s%d", i ? ", " : "", tokens[i]); + } + printf("}\n"); + } else { + // verify against known sequence + GGML_ASSERT((int)expected.size() == n_samples); + bool match = true; + for (int i = 0; i < n_samples; i++) { + if (tokens[i] != expected[i]) { + match = false; + break; + } + } + if (!match) { + printf("test_dist_rng %s: MISMATCH\n got: {", desc); + for (int i = 0; i < n_samples; i++) { + printf("%s%d", i ? ", " : "", tokens[i]); + } + printf("}\n expected: {"); + for (int i = 0; i < n_samples; i++) { + printf("%s%d", i ? ", " : "", expected[i]); + } + printf("}\n"); + GGML_ASSERT(false); + } + + // also verify reset reproduces same sequence + llama_sampler_reset(sampler); + for (int i = 0; i < n_samples; i++) { + std::vector cur(data); + llama_token_data_array cur_p = {cur.data(), cur.size(), -1, false}; + llama_sampler_apply(sampler, &cur_p); + GGML_ASSERT(cur_p.data[cur_p.selected].id == tokens[i]); + } + + printf("test_dist_rng %-30s OK\n", desc); + } + + llama_sampler_free(sampler); +} + static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p ) { sampler_tester tester(n_vocab); @@ -392,6 +459,13 @@ int main(void) { test_sampler_queue(10000, "mkp", 100, 0.8f, 0.1f); test_sampler_queue(10000, "mpk", 100, 0.8f, 0.1f); + test_dist_rng(42, false, LLAMA_RNG_TYPE_LOWBIAS32, + {5, 12, 8, 10, 12, 11, 10, 8, 8, 10, 11, 9, 7, 6, 11, 13, 14, 15, 13, 4, 12, 14, 13, 13, 14, 12, 5, 15, 4, 13, 15, 12}, + "lowbias32"); + test_dist_rng(42, true, LLAMA_RNG_TYPE_LOWBIAS32, + {10, 5, 12, 8, 15, 13, 3, 10, 13, 12, 2, 15, 8, 14, 5, 11, 7, 9, 15, 11, 8, 2, 12, 14, 7, 9, 13, 10, 14, 5, 12, 15}, + "lowbias32 + blue noise"); + printf("OK\n"); test_perf();