sampling : fix greedy

This commit is contained in:
Georgi Gerganov 2025-12-11 13:37:02 +02:00
parent 8544aba37f
commit 74b112e3e7
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 52 additions and 9 deletions

View File

@ -920,6 +920,30 @@ static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl)
return sctx->get_name(); return sctx->get_name();
} }
// Reset hook for the greedy sampler. Greedy sampling keeps no mutable
// state, so there is nothing to restore; the handler exists so the sampler
// exposes a non-null .reset callback like the other samplers.
static void llama_sampler_greedy_reset(struct llama_sampler * smpl) {
    auto * sctx = (llama_sampler_greedy *) smpl->ctx;
    GGML_UNUSED(sctx);
}
// Clone hook for the greedy sampler: returns a fresh, independent greedy
// sampler instance. llama_sampler_greedy carries no per-instance state
// visible here, so there is nothing to copy into the new context.
static struct llama_sampler * llama_sampler_greedy_clone(const struct llama_sampler * smpl) {
    const auto * src_ctx = (const llama_sampler_greedy *) smpl->ctx;

    struct llama_sampler * cloned = llama_sampler_init_greedy();

    // no state to transfer; the casts are kept so this mirrors the clone
    // implementations of the stateful samplers
    auto * dst_ctx = (llama_sampler_greedy *) cloned->ctx;
    GGML_UNUSED(src_ctx);
    GGML_UNUSED(dst_ctx);

    return cloned;
}
static void llama_sampler_greedy_free(struct llama_sampler * smpl) {
delete (llama_sampler_greedy *) smpl->ctx;
}
static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
cur_p->selected = 0; cur_p->selected = 0;
for (size_t i = 1; i < cur_p->size; ++i) { for (size_t i = 1; i < cur_p->size; ++i) {
@ -959,9 +983,9 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
/* .name = */ llama_sampler_greedy_name, /* .name = */ llama_sampler_greedy_name,
/* .accept = */ nullptr, /* .accept = */ nullptr,
/* .apply = */ llama_sampler_greedy_apply, /* .apply = */ llama_sampler_greedy_apply,
/* .reset = */ nullptr, /* .reset = */ llama_sampler_greedy_reset,
/* .clone = */ nullptr, /* .clone = */ llama_sampler_greedy_clone,
/* .free = */ nullptr, /* .free = */ llama_sampler_greedy_free,
/* .backend_init = */ llama_sampler_greedy_backend_init, /* .backend_init = */ llama_sampler_greedy_backend_init,
/* .backend_accept = */ nullptr, /* .backend_accept = */ nullptr,
/* .backend_apply = */ llama_sampler_greedy_backend_apply, /* .backend_apply = */ llama_sampler_greedy_backend_apply,
@ -1069,6 +1093,12 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
#endif #endif
} }
// Reset hook for the dist sampler: recomputes the effective seed from the
// user-supplied one via get_rng_seed and re-seeds the RNG, so subsequent
// sampling restarts from a deterministic state.
static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
    auto * sctx = (llama_sampler_dist *) smpl->ctx;

    const auto seed_cur = get_rng_seed(sctx->seed);

    sctx->seed_cur = seed_cur;
    sctx->rng.seed(seed_cur);
}
static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_dist *) smpl->ctx; const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
auto * result = llama_sampler_init_dist(ctx->seed); auto * result = llama_sampler_init_dist(ctx->seed);
@ -1083,12 +1113,6 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
return result; return result;
} }
static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_dist *) smpl->ctx;
ctx->seed_cur = get_rng_seed(ctx->seed);
ctx->rng.seed(ctx->seed_cur);
}
static void llama_sampler_dist_free(struct llama_sampler * smpl) { static void llama_sampler_dist_free(struct llama_sampler * smpl) {
delete (llama_sampler_dist *) smpl->ctx; delete (llama_sampler_dist *) smpl->ctx;
} }

View File

@ -314,6 +314,8 @@ static void test_backend_greedy_sampling(const char * model_path) {
GGML_ASSERT(false && "Failed to decode token"); GGML_ASSERT(false && "Failed to decode token");
} }
} }
llama_sampler_free(backend_sampler_chain);
} }
static void test_backend_top_k_sampling(const char * model_path) { static void test_backend_top_k_sampling(const char * model_path) {
@ -349,6 +351,8 @@ static void test_backend_top_k_sampling(const char * model_path) {
test_ctx.token_to_piece(candidates[i], false).c_str()); test_ctx.token_to_piece(candidates[i], false).c_str());
} }
llama_sampler_free(backend_sampler_chain);
// Sample using CPU sampler for verification that it is possible to do hybrid // Sample using CPU sampler for verification that it is possible to do hybrid
// sampling, first top_k on the backend and then dist on the CPU. // sampling, first top_k on the backend and then dist on the CPU.
struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
@ -392,6 +396,9 @@ static void test_backend_temp_sampling(const char * model_path) {
GGML_ASSERT(false && "Failed to decode token"); GGML_ASSERT(false && "Failed to decode token");
} }
llama_sampler_free(backend_sampler_chain_0);
llama_sampler_free(backend_sampler_chain_1);
// Verify sequence 0 // Verify sequence 0
{ {
int32_t batch_idx = test_ctx.idx_for_seq(0); int32_t batch_idx = test_ctx.idx_for_seq(0);
@ -457,6 +464,8 @@ static void test_backend_temp_sampling(const char * model_path) {
uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx); uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
GGML_ASSERT(n_logits == 1); GGML_ASSERT(n_logits == 1);
llama_sampler_free(backend_sampler_chain);
}; };
test_argmax_temp(0.0f); test_argmax_temp(0.0f);
@ -496,6 +505,8 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx); int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
GGML_ASSERT(n_logits == test_ctx.n_vocab); GGML_ASSERT(n_logits == test_ctx.n_vocab);
} }
llama_sampler_free(backend_sampler_chain);
} }
test_ctx.reset(); test_ctx.reset();
@ -532,6 +543,8 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
} else { } else {
GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab); GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
} }
llama_sampler_free(backend_sampler_chain);
}; };
test_argmax_temp(0.0f, 0.3f, 1.0f); // Greedy (temp=0) test_argmax_temp(0.0f, 0.3f, 1.0f); // Greedy (temp=0)
@ -597,6 +610,7 @@ static void test_backend_min_p_sampling(const char * model_path) {
printf("min-p sampling test PASSED\n"); printf("min-p sampling test PASSED\n");
llama_sampler_free(backend_sampler_chain);
llama_sampler_free(chain); llama_sampler_free(chain);
} }
@ -653,6 +667,7 @@ static void test_backend_top_p_sampling(const char * model_path) {
printf("top-p sampling test PASSED\n"); printf("top-p sampling test PASSED\n");
llama_sampler_free(backend_sampler_chain);
llama_sampler_free(chain); llama_sampler_free(chain);
} }
@ -723,6 +738,9 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
} }
} }
llama_sampler_free(sampler_chain_0);
llama_sampler_free(sampler_chain_1);
printf("backend multi-sequence sampling test PASSED\n"); printf("backend multi-sequence sampling test PASSED\n");
} }
@ -755,6 +773,7 @@ static void test_backend_dist_sampling(const char * model_path) {
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
llama_sampler_free(backend_sampler_chain); llama_sampler_free(backend_sampler_chain);
printf("backend dist sampling test PASSED\n"); printf("backend dist sampling test PASSED\n");
} }