sampling : fix greedy
This commit is contained in:
parent
8544aba37f
commit
74b112e3e7
|
|
@ -920,6 +920,30 @@ static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl)
|
|||
return sctx->get_name();
|
||||
}
|
||||
|
||||
static void llama_sampler_greedy_reset(struct llama_sampler * smpl) {
|
||||
auto * ctx = (llama_sampler_greedy *) smpl->ctx;
|
||||
GGML_UNUSED(ctx);
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_greedy_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_greedy *) smpl->ctx;
|
||||
auto * result = llama_sampler_init_greedy();
|
||||
|
||||
// copy the state
|
||||
{
|
||||
auto * result_ctx = (llama_sampler_greedy *) result->ctx;
|
||||
|
||||
GGML_UNUSED(ctx);
|
||||
GGML_UNUSED(result_ctx);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static void llama_sampler_greedy_free(struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_greedy *) smpl->ctx;
|
||||
}
|
||||
|
||||
static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
|
||||
cur_p->selected = 0;
|
||||
for (size_t i = 1; i < cur_p->size; ++i) {
|
||||
|
|
@ -959,9 +983,9 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
|
|||
/* .name = */ llama_sampler_greedy_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_greedy_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ nullptr,
|
||||
/* .free = */ nullptr,
|
||||
/* .reset = */ llama_sampler_greedy_reset,
|
||||
/* .clone = */ llama_sampler_greedy_clone,
|
||||
/* .free = */ llama_sampler_greedy_free,
|
||||
/* .backend_init = */ llama_sampler_greedy_backend_init,
|
||||
/* .backend_accept = */ nullptr,
|
||||
/* .backend_apply = */ llama_sampler_greedy_backend_apply,
|
||||
|
|
@ -1069,6 +1093,12 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da
|
|||
#endif
|
||||
}
|
||||
|
||||
static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
|
||||
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
||||
ctx->seed_cur = get_rng_seed(ctx->seed);
|
||||
ctx->rng.seed(ctx->seed_cur);
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
|
||||
const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
|
||||
auto * result = llama_sampler_init_dist(ctx->seed);
|
||||
|
|
@ -1083,12 +1113,6 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
|
|||
return result;
|
||||
}
|
||||
|
||||
static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
|
||||
auto * ctx = (llama_sampler_dist *) smpl->ctx;
|
||||
ctx->seed_cur = get_rng_seed(ctx->seed);
|
||||
ctx->rng.seed(ctx->seed_cur);
|
||||
}
|
||||
|
||||
static void llama_sampler_dist_free(struct llama_sampler * smpl) {
|
||||
delete (llama_sampler_dist *) smpl->ctx;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -314,6 +314,8 @@ static void test_backend_greedy_sampling(const char * model_path) {
|
|||
GGML_ASSERT(false && "Failed to decode token");
|
||||
}
|
||||
}
|
||||
|
||||
llama_sampler_free(backend_sampler_chain);
|
||||
}
|
||||
|
||||
static void test_backend_top_k_sampling(const char * model_path) {
|
||||
|
|
@ -349,6 +351,8 @@ static void test_backend_top_k_sampling(const char * model_path) {
|
|||
test_ctx.token_to_piece(candidates[i], false).c_str());
|
||||
}
|
||||
|
||||
llama_sampler_free(backend_sampler_chain);
|
||||
|
||||
// Sample using CPU sampler for verification that it is possible to do hybrid
|
||||
// sampling, first top_k on the backend and then dist on the CPU.
|
||||
struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
|
||||
|
|
@ -392,6 +396,9 @@ static void test_backend_temp_sampling(const char * model_path) {
|
|||
GGML_ASSERT(false && "Failed to decode token");
|
||||
}
|
||||
|
||||
llama_sampler_free(backend_sampler_chain_0);
|
||||
llama_sampler_free(backend_sampler_chain_1);
|
||||
|
||||
// Verfify sequence 0
|
||||
{
|
||||
int32_t batch_idx = test_ctx.idx_for_seq(0);
|
||||
|
|
@ -457,6 +464,8 @@ static void test_backend_temp_sampling(const char * model_path) {
|
|||
|
||||
uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
|
||||
GGML_ASSERT(n_logits == 1);
|
||||
|
||||
llama_sampler_free(backend_sampler_chain);
|
||||
};
|
||||
|
||||
test_argmax_temp(0.0f);
|
||||
|
|
@ -496,6 +505,8 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
|
|||
int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
|
||||
GGML_ASSERT(n_logits == test_ctx.n_vocab);
|
||||
}
|
||||
|
||||
llama_sampler_free(backend_sampler_chain);
|
||||
}
|
||||
|
||||
test_ctx.reset();
|
||||
|
|
@ -532,6 +543,8 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
|
|||
} else {
|
||||
GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
|
||||
}
|
||||
|
||||
llama_sampler_free(backend_sampler_chain);
|
||||
};
|
||||
|
||||
test_argmax_temp(0.0f, 0.3f, 1.0f); // Greedy (temp=0)
|
||||
|
|
@ -597,6 +610,7 @@ static void test_backend_min_p_sampling(const char * model_path) {
|
|||
|
||||
printf("min-p sampling test PASSED\n");
|
||||
|
||||
llama_sampler_free(backend_sampler_chain);
|
||||
llama_sampler_free(chain);
|
||||
}
|
||||
|
||||
|
|
@ -653,6 +667,7 @@ static void test_backend_top_p_sampling(const char * model_path) {
|
|||
|
||||
printf("top-p sampling test PASSED\n");
|
||||
|
||||
llama_sampler_free(backend_sampler_chain);
|
||||
llama_sampler_free(chain);
|
||||
}
|
||||
|
||||
|
|
@ -723,6 +738,9 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
|
|||
}
|
||||
}
|
||||
|
||||
llama_sampler_free(sampler_chain_0);
|
||||
llama_sampler_free(sampler_chain_1);
|
||||
|
||||
printf("backend multi-sequence sampling test PASSED\n");
|
||||
}
|
||||
|
||||
|
|
@ -755,6 +773,7 @@ static void test_backend_dist_sampling(const char * model_path) {
|
|||
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
|
||||
|
||||
llama_sampler_free(backend_sampler_chain);
|
||||
|
||||
printf("backend dist sampling test PASSED\n");
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue