sampling : fix greedy

Georgi Gerganov 2025-12-11 13:37:02 +02:00
parent 8544aba37f
commit 74b112e3e7
2 changed files with 52 additions and 9 deletions
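
The change wires up the previously-nullptr reset/clone/free callbacks of the greedy sampler, moves llama_sampler_dist_reset next to the other dist callbacks, and plugs sampler-chain leaks in the backend sampling tests. A minimal sketch of the public API path this affects, assuming the standard llama.h sampler entry points:

    // hypothetical usage sketch, not part of the commit
    struct llama_sampler * smpl = llama_sampler_init_greedy();
    struct llama_sampler * copy = llama_sampler_clone(smpl); // now dispatches to llama_sampler_greedy_clone
    llama_sampler_reset(copy);                               // now dispatches to llama_sampler_greedy_reset
    llama_sampler_free(copy);                                // now dispatches to llama_sampler_greedy_free
    llama_sampler_free(smpl);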

@@ -920,6 +920,30 @@ static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl)
     return sctx->get_name();
 }
 
+static void llama_sampler_greedy_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_greedy *) smpl->ctx;
+    GGML_UNUSED(ctx);
+}
+
+static struct llama_sampler * llama_sampler_greedy_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_greedy *) smpl->ctx;
+    auto * result = llama_sampler_init_greedy();
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_greedy *) result->ctx;
+
+        GGML_UNUSED(ctx);
+        GGML_UNUSED(result_ctx);
+    }
+
+    return result;
+}
+
+static void llama_sampler_greedy_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_greedy *) smpl->ctx;
+}
+
 static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
     cur_p->selected = 0;
     for (size_t i = 1; i < cur_p->size; ++i) {
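
For reference, the hunk context above cuts off mid-loop; the body of llama_sampler_greedy_apply is a plain argmax over the candidate logits and is unchanged by this commit. A sketch of the full function:

    static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
        // select the candidate with the highest logit
        cur_p->selected = 0;
        for (size_t i = 1; i < cur_p->size; ++i) {
            if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
                cur_p->selected = i;
            }
        }
    }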
@@ -959,9 +983,9 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
     /* .name   = */ llama_sampler_greedy_name,
     /* .accept = */ nullptr,
     /* .apply  = */ llama_sampler_greedy_apply,
-    /* .reset  = */ nullptr,
-    /* .clone  = */ nullptr,
-    /* .free   = */ nullptr,
+    /* .reset  = */ llama_sampler_greedy_reset,
+    /* .clone  = */ llama_sampler_greedy_clone,
+    /* .free   = */ llama_sampler_greedy_free,
     /* .backend_init   = */ llama_sampler_greedy_backend_init,
     /* .backend_accept = */ nullptr,
     /* .backend_apply  = */ llama_sampler_greedy_backend_apply,
@@ -1069,6 +1093,12 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
 #endif
 }
 
+static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_dist *) smpl->ctx;
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
+}
+
 static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
     auto * result = llama_sampler_init_dist(ctx->seed);
@@ -1083,12 +1113,6 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
     return result;
 }
 
-static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
-    auto * ctx = (llama_sampler_dist *) smpl->ctx;
-    ctx->seed_cur = get_rng_seed(ctx->seed);
-    ctx->rng.seed(ctx->seed_cur);
-}
-
 static void llama_sampler_dist_free(struct llama_sampler * smpl) {
     delete (llama_sampler_dist *) smpl->ctx;
 }
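
The two dist hunks only move llama_sampler_dist_reset above llama_sampler_dist_clone, presumably to match the reset/clone/free order used for the other samplers; the body is unchanged. With a fixed seed the reset is observable through the public API, as in this sketch:

    // usage sketch: with a fixed seed, reset re-derives seed_cur and re-seeds
    // the RNG, so the sampler replays the same sequence of draws
    struct llama_sampler * dist = llama_sampler_init_dist(42);
    // ... sample a few tokens ...
    llama_sampler_reset(dist); // calls llama_sampler_dist_reset
    // ... the same draws repeat from here ...
    llama_sampler_free(dist);

The second changed file is the backend sampling tests, where each test now frees the sampler chains it creates: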

@@ -314,6 +314,8 @@ static void test_backend_greedy_sampling(const char * model_path) {
             GGML_ASSERT(false && "Failed to decode token");
         }
     }
+
+    llama_sampler_free(backend_sampler_chain);
 }
 
 static void test_backend_top_k_sampling(const char * model_path) {
@@ -349,6 +351,8 @@ static void test_backend_top_k_sampling(const char * model_path) {
                test_ctx.token_to_piece(candidates[i], false).c_str());
     }
 
+    llama_sampler_free(backend_sampler_chain);
+
     // Sample using CPU sampler for verification that it is possible to do hybrid
     // sampling, first top_k on the backend and then dist on the CPU.
     struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
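
The comment above describes the hybrid setup: top-k runs on the backend and the final dist draw happens on the CPU. A sketch of the CPU-side chain the test builds next (the seed value here is illustrative, not the test's):

    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
    llama_sampler_chain_add(chain, llama_sampler_init_dist(42)); // illustrative seed
    // llama_sampler_sample(chain, ctx, batch_idx) then draws from the
    // candidates that the backend top-k sampler left in place
    llama_sampler_free(chain); // the kind of cleanup this commit adds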
@@ -392,6 +396,9 @@ static void test_backend_temp_sampling(const char * model_path) {
         GGML_ASSERT(false && "Failed to decode token");
     }
 
+    llama_sampler_free(backend_sampler_chain_0);
+    llama_sampler_free(backend_sampler_chain_1);
+
     // Verify sequence 0
     {
         int32_t batch_idx = test_ctx.idx_for_seq(0);
@@ -457,6 +464,8 @@ static void test_backend_temp_sampling(const char * model_path) {
         uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
         GGML_ASSERT(n_logits == 1);
 
+        llama_sampler_free(backend_sampler_chain);
     };
 
     test_argmax_temp(0.0f);
@@ -496,6 +505,8 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
             int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
             GGML_ASSERT(n_logits == test_ctx.n_vocab);
         }
+
+        llama_sampler_free(backend_sampler_chain);
     }
 
     test_ctx.reset();
@@ -532,6 +543,8 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
         } else {
            GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
         }
+
+        llama_sampler_free(backend_sampler_chain);
     };
 
     test_argmax_temp(0.0f, 0.3f, 1.0f); // Greedy (temp=0)
@@ -597,6 +610,7 @@ static void test_backend_min_p_sampling(const char * model_path) {
     printf("min-p sampling test PASSED\n");
 
+    llama_sampler_free(backend_sampler_chain);
     llama_sampler_free(chain);
 }
@@ -653,6 +667,7 @@ static void test_backend_top_p_sampling(const char * model_path) {
     printf("top-p sampling test PASSED\n");
 
+    llama_sampler_free(backend_sampler_chain);
     llama_sampler_free(chain);
 }
@@ -723,6 +738,9 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
         }
     }
 
+    llama_sampler_free(sampler_chain_0);
+    llama_sampler_free(sampler_chain_1);
+
     printf("backend multi-sequence sampling test PASSED\n");
 }
@@ -755,6 +773,7 @@ static void test_backend_dist_sampling(const char * model_path) {
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 
+    llama_sampler_free(backend_sampler_chain);
     printf("backend dist sampling test PASSED\n");
 }
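
Taken together, the test-side changes are mechanical: every sampler chain the tests construct now gets a matching llama_sampler_free, so the fixtures no longer leak sampler state; leaks like these typically surface under a leak checker rather than as test failures.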