sampling : do not create empty samplers
This commit is contained in:
parent
4032ce2378
commit
04f2822a86
|
|
@ -468,7 +468,41 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
|
||||||
return token;
|
return token;
|
||||||
}
|
}
|
||||||
|
|
||||||
// backend sampling (empty iface)
|
// empty sampler
|
||||||
|
|
||||||
|
struct llama_sampler_empty {
|
||||||
|
const char * name;
|
||||||
|
};
|
||||||
|
|
||||||
|
static struct llama_sampler * llama_sampler_init_empty(const char * name);
|
||||||
|
|
||||||
|
static const char * llama_sampler_empty_name(const struct llama_sampler * smpl) {
|
||||||
|
auto * ctx = (llama_sampler_empty *) smpl->ctx;
|
||||||
|
return ctx->name;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void llama_sampler_empty_accept(struct llama_sampler * smpl, llama_token token) {
|
||||||
|
GGML_UNUSED(smpl);
|
||||||
|
GGML_UNUSED(token);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void llama_sampler_empty_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
|
GGML_UNUSED(smpl);
|
||||||
|
GGML_UNUSED(cur_p);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void llama_sampler_empty_reset(struct llama_sampler * smpl) {
|
||||||
|
GGML_UNUSED(smpl);
|
||||||
|
}
|
||||||
|
|
||||||
|
static struct llama_sampler * llama_sampler_empty_clone(const struct llama_sampler * smpl) {
|
||||||
|
auto * ctx = (llama_sampler_empty *) smpl->ctx;
|
||||||
|
return llama_sampler_init_empty(ctx->name);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void llama_sampler_empty_free(struct llama_sampler * smpl) {
|
||||||
|
delete (llama_sampler_empty *) smpl->ctx;
|
||||||
|
}
|
||||||
|
|
||||||
static void llama_sampler_empty_backend_init(
|
static void llama_sampler_empty_backend_init(
|
||||||
struct llama_sampler * smpl,
|
struct llama_sampler * smpl,
|
||||||
|
|
@ -503,6 +537,27 @@ static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) {
|
||||||
GGML_UNUSED(smpl);
|
GGML_UNUSED(smpl);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static struct llama_sampler_i llama_sampler_empty_i = {
|
||||||
|
/* .name = */ llama_sampler_empty_name,
|
||||||
|
/* .accept = */ llama_sampler_empty_accept,
|
||||||
|
/* .apply = */ llama_sampler_empty_apply,
|
||||||
|
/* .reset = */ llama_sampler_empty_reset,
|
||||||
|
/* .clone = */ llama_sampler_empty_clone,
|
||||||
|
/* .free = */ llama_sampler_empty_free,
|
||||||
|
/* .backend_init = */ llama_sampler_empty_backend_init,
|
||||||
|
/* .backend_accept = */ llama_sampler_empty_backend_accept,
|
||||||
|
/* .backend_apply = */ llama_sampler_empty_backend_apply,
|
||||||
|
/* .backend_set_input = */ llama_sampler_empty_backend_set_input,
|
||||||
|
};
|
||||||
|
|
||||||
|
struct llama_sampler * llama_sampler_init_empty(const char * name) {
|
||||||
|
return llama_sampler_init(
|
||||||
|
/* .iface = */ &llama_sampler_empty_i,
|
||||||
|
/* .ctx = */ new llama_sampler_empty {
|
||||||
|
/* .name = */ name,
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// sampler chain
|
// sampler chain
|
||||||
|
|
||||||
|
|
@ -1040,6 +1095,12 @@ static struct llama_sampler_i llama_sampler_top_k_i = {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
||||||
|
const bool is_empty = (k <= 0);
|
||||||
|
|
||||||
|
if (is_empty) {
|
||||||
|
return llama_sampler_init_empty("top-k?");
|
||||||
|
}
|
||||||
|
|
||||||
return llama_sampler_init(
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_top_k_i,
|
/* .iface = */ &llama_sampler_top_k_i,
|
||||||
/* .ctx = */ new llama_sampler_top_k {
|
/* .ctx = */ new llama_sampler_top_k {
|
||||||
|
|
@ -1226,6 +1287,12 @@ static struct llama_sampler_i llama_sampler_top_p_i = {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
|
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
|
||||||
|
const bool is_empty = (p <= 0.0f);
|
||||||
|
|
||||||
|
if (is_empty) {
|
||||||
|
return llama_sampler_init_empty("top-p?");
|
||||||
|
}
|
||||||
|
|
||||||
return llama_sampler_init(
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_top_p_i,
|
/* .iface = */ &llama_sampler_top_p_i,
|
||||||
/* .ctx = */ new llama_sampler_top_p {
|
/* .ctx = */ new llama_sampler_top_p {
|
||||||
|
|
@ -1378,6 +1445,12 @@ static struct llama_sampler_i llama_sampler_min_p_i = {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
|
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
|
||||||
|
const bool is_empty = (p <= 0.0f);
|
||||||
|
|
||||||
|
if (is_empty) {
|
||||||
|
return llama_sampler_init_empty("min-p?");
|
||||||
|
}
|
||||||
|
|
||||||
return llama_sampler_init(
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_min_p_i,
|
/* .iface = */ &llama_sampler_min_p_i,
|
||||||
/* .ctx = */ new llama_sampler_min_p {
|
/* .ctx = */ new llama_sampler_min_p {
|
||||||
|
|
@ -1482,24 +1555,19 @@ static struct llama_sampler_i llama_sampler_typical_i = {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
|
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
|
||||||
auto * res = llama_sampler_init(
|
const bool is_empty = (p >= 1.0f);
|
||||||
|
|
||||||
|
if (is_empty) {
|
||||||
|
return llama_sampler_init_empty("typical?");
|
||||||
|
}
|
||||||
|
|
||||||
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_typical_i,
|
/* .iface = */ &llama_sampler_typical_i,
|
||||||
/* .ctx = */ new llama_sampler_typical {
|
/* .ctx = */ new llama_sampler_typical {
|
||||||
/* .p = */ p,
|
/* .p = */ p,
|
||||||
/* .min_keep = */ min_keep,
|
/* .min_keep = */ min_keep,
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
const bool is_empty = (p >= 1.0f);
|
|
||||||
|
|
||||||
if (is_empty) {
|
|
||||||
res->iface->backend_init = llama_sampler_empty_backend_init;
|
|
||||||
res->iface->backend_accept = llama_sampler_empty_backend_accept;
|
|
||||||
res->iface->backend_apply = llama_sampler_empty_backend_apply;
|
|
||||||
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
|
|
||||||
}
|
|
||||||
|
|
||||||
return res;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// temp
|
// temp
|
||||||
|
|
@ -1535,6 +1603,7 @@ static void llama_sampler_temp_backend_apply(
|
||||||
auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
|
auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
|
||||||
|
|
||||||
if (ctx_data->temp <= 0.0f) {
|
if (ctx_data->temp <= 0.0f) {
|
||||||
|
// TODO: this is incorrect - find the most probable token instead
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1562,6 +1631,12 @@ static struct llama_sampler_i llama_sampler_temp_i = {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_temp(float temp) {
|
struct llama_sampler * llama_sampler_init_temp(float temp) {
|
||||||
|
const bool is_empty = temp == 1.0f;
|
||||||
|
|
||||||
|
if (is_empty) {
|
||||||
|
return llama_sampler_init_empty("temp?");
|
||||||
|
}
|
||||||
|
|
||||||
return llama_sampler_init(
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_temp_i,
|
/* .iface = */ &llama_sampler_temp_i,
|
||||||
/* .ctx = */ new llama_sampler_temp {
|
/* .ctx = */ new llama_sampler_temp {
|
||||||
|
|
@ -1662,14 +1737,19 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
|
||||||
delete (llama_sampler_temp_ext *) smpl->ctx;
|
delete (llama_sampler_temp_ext *) smpl->ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: deduplicate with llama_sampler_temp_backend_apply
|
||||||
static void llama_sampler_temp_ext_backend_apply(
|
static void llama_sampler_temp_ext_backend_apply(
|
||||||
struct llama_sampler * smpl,
|
struct llama_sampler * smpl,
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_cgraph * gf,
|
struct ggml_cgraph * gf,
|
||||||
struct llama_sampler_data * data) {
|
struct llama_sampler_data * data) {
|
||||||
auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
|
auto * ctx_data = (llama_sampler_temp_ext *) smpl->ctx;
|
||||||
|
|
||||||
|
// TODO: implement
|
||||||
|
GGML_ASSERT(ctx_data->delta <= 0.0f && "not implemented");
|
||||||
|
|
||||||
if (ctx_data->temp <= 0.0f) {
|
if (ctx_data->temp <= 0.0f) {
|
||||||
|
// TODO: this is incorrect - find the most probable token instead
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1697,6 +1777,12 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
|
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
|
||||||
|
const bool is_empty = temp == 1.0f && delta <= 0.0f;
|
||||||
|
|
||||||
|
if (is_empty) {
|
||||||
|
return llama_sampler_init_empty("temp-ext?");
|
||||||
|
}
|
||||||
|
|
||||||
auto * res = llama_sampler_init(
|
auto * res = llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_temp_ext_i,
|
/* .iface = */ &llama_sampler_temp_ext_i,
|
||||||
/* .ctx = */ new llama_sampler_temp_ext {
|
/* .ctx = */ new llama_sampler_temp_ext {
|
||||||
|
|
@ -1803,9 +1889,15 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
|
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
|
||||||
|
const bool is_empty = (p <= 0.0f || t > 0.5f);
|
||||||
|
|
||||||
|
if (is_empty) {
|
||||||
|
return llama_sampler_init_empty("xtc?");
|
||||||
|
}
|
||||||
|
|
||||||
const auto seed_cur = get_rng_seed(seed);
|
const auto seed_cur = get_rng_seed(seed);
|
||||||
|
|
||||||
auto * res = llama_sampler_init(
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_xtc_i,
|
/* .iface = */ &llama_sampler_xtc_i,
|
||||||
/* .ctx = */ new llama_sampler_xtc {
|
/* .ctx = */ new llama_sampler_xtc {
|
||||||
/* .probability = */ p,
|
/* .probability = */ p,
|
||||||
|
|
@ -1816,17 +1908,6 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
|
||||||
/* .rng = */ std::mt19937(seed_cur),
|
/* .rng = */ std::mt19937(seed_cur),
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
const bool is_empty = (p <= 0.0f || t > 0.5f);
|
|
||||||
|
|
||||||
if (is_empty) {
|
|
||||||
res->iface->backend_init = llama_sampler_empty_backend_init;
|
|
||||||
res->iface->backend_accept = llama_sampler_empty_backend_accept;
|
|
||||||
res->iface->backend_apply = llama_sampler_empty_backend_apply;
|
|
||||||
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
|
|
||||||
}
|
|
||||||
|
|
||||||
return res;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// mirostat
|
// mirostat
|
||||||
|
|
@ -1927,7 +2008,8 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
|
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
|
||||||
auto seed_cur = get_rng_seed(seed);
|
const auto seed_cur = get_rng_seed(seed);
|
||||||
|
|
||||||
return llama_sampler_init(
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_mirostat_i,
|
/* .iface = */ &llama_sampler_mirostat_i,
|
||||||
/* .ctx = */ new llama_sampler_mirostat {
|
/* .ctx = */ new llama_sampler_mirostat {
|
||||||
|
|
@ -2368,7 +2450,13 @@ struct llama_sampler * llama_sampler_init_penalties(
|
||||||
float penalty_present) {
|
float penalty_present) {
|
||||||
penalty_last_n = std::max(penalty_last_n, 0);
|
penalty_last_n = std::max(penalty_last_n, 0);
|
||||||
|
|
||||||
auto * res = llama_sampler_init(
|
const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
|
||||||
|
|
||||||
|
if (is_empty) {
|
||||||
|
return llama_sampler_init_empty("penalties?");
|
||||||
|
}
|
||||||
|
|
||||||
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_penalties_i,
|
/* .iface = */ &llama_sampler_penalties_i,
|
||||||
/* .ctx = */ new llama_sampler_penalties {
|
/* .ctx = */ new llama_sampler_penalties {
|
||||||
/* .penalty_last_n = */ penalty_last_n,
|
/* .penalty_last_n = */ penalty_last_n,
|
||||||
|
|
@ -2379,17 +2467,6 @@ struct llama_sampler * llama_sampler_init_penalties(
|
||||||
/* .token_count = */ {},
|
/* .token_count = */ {},
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
|
|
||||||
|
|
||||||
if (is_empty) {
|
|
||||||
res->iface->backend_init = llama_sampler_empty_backend_init;
|
|
||||||
res->iface->backend_accept = llama_sampler_empty_backend_accept;
|
|
||||||
res->iface->backend_apply = llama_sampler_empty_backend_apply;
|
|
||||||
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
|
|
||||||
}
|
|
||||||
|
|
||||||
return res;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// top-n-sigma
|
// top-n-sigma
|
||||||
|
|
@ -2466,23 +2543,18 @@ static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
|
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
|
||||||
auto * res = llama_sampler_init(
|
const bool is_empty = (n <= 0.0f);
|
||||||
|
|
||||||
|
if (is_empty) {
|
||||||
|
return llama_sampler_init_empty("top-n-sigma?");
|
||||||
|
}
|
||||||
|
|
||||||
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_top_n_sigma_i,
|
/* .iface = */ &llama_sampler_top_n_sigma_i,
|
||||||
/* .ctx = */ new llama_sampler_top_n_sigma {
|
/* .ctx = */ new llama_sampler_top_n_sigma {
|
||||||
/* .n = */ n,
|
/* .n = */ n,
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
const bool is_empty = (n <= 0.0f);
|
|
||||||
|
|
||||||
if (is_empty) {
|
|
||||||
res->iface->backend_init = llama_sampler_empty_backend_init;
|
|
||||||
res->iface->backend_accept = llama_sampler_empty_backend_accept;
|
|
||||||
res->iface->backend_apply = llama_sampler_empty_backend_apply;
|
|
||||||
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
|
|
||||||
}
|
|
||||||
|
|
||||||
return res;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DRY
|
// DRY
|
||||||
|
|
@ -2818,6 +2890,10 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
|
||||||
|
|
||||||
const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
|
const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
|
||||||
|
|
||||||
|
if (!dry_enabled) {
|
||||||
|
return llama_sampler_init_empty("dry?");
|
||||||
|
}
|
||||||
|
|
||||||
if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
|
if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
|
||||||
// Process sequence breakers
|
// Process sequence breakers
|
||||||
for (size_t i = 0; i < num_breakers; ++i) {
|
for (size_t i = 0; i < num_breakers; ++i) {
|
||||||
|
|
@ -2841,7 +2917,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto * res = llama_sampler_init(
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_dry_i,
|
/* .iface = */ &llama_sampler_dry_i,
|
||||||
/* .ctx = */ new llama_sampler_dry {
|
/* .ctx = */ new llama_sampler_dry {
|
||||||
/* .total_context_size = */ n_ctx_train,
|
/* .total_context_size = */ n_ctx_train,
|
||||||
|
|
@ -2855,15 +2931,6 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
|
||||||
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
|
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
if (!dry_enabled) {
|
|
||||||
res->iface->backend_init = llama_sampler_empty_backend_init;
|
|
||||||
res->iface->backend_accept = llama_sampler_empty_backend_accept;
|
|
||||||
res->iface->backend_apply = llama_sampler_empty_backend_apply;
|
|
||||||
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
|
|
||||||
}
|
|
||||||
|
|
||||||
return res;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// wrapper for test-sampling.cpp
|
// wrapper for test-sampling.cpp
|
||||||
|
|
@ -3035,6 +3102,12 @@ struct llama_sampler * llama_sampler_init_logit_bias(
|
||||||
int32_t n_vocab,
|
int32_t n_vocab,
|
||||||
int32_t n_logit_bias,
|
int32_t n_logit_bias,
|
||||||
const llama_logit_bias * logit_bias) {
|
const llama_logit_bias * logit_bias) {
|
||||||
|
const bool is_empty = n_logit_bias <= 0;
|
||||||
|
|
||||||
|
if (is_empty) {
|
||||||
|
return llama_sampler_init_empty("logit-bias?");
|
||||||
|
}
|
||||||
|
|
||||||
return llama_sampler_init(
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_logit_bias_i,
|
/* .iface = */ &llama_sampler_logit_bias_i,
|
||||||
/* .ctx = */ new llama_sampler_logit_bias {
|
/* .ctx = */ new llama_sampler_logit_bias {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue