common : simplify sampler chain initialization
This commit is contained in:
parent 217469f07f
commit 4032ce2378
@@ -212,18 +212,12 @@ struct common_params_sampling {
     std::vector<llama_logit_bias> logit_bias;     // logit biases to apply
     std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens

-    bool backend_sampling = false; // enable backend sampling
+    bool backend_sampling = false;

     bool has_logit_bias() const {
         return !logit_bias.empty();
     }

-    bool is_disabled(enum common_sampler_type type) const;
-
-    // remove disabled samplers
-    // TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
-    void filter_disabled();
-
     // print the parameters into a string
     std::string print() const;
 };
@@ -661,7 +655,7 @@ std::vector<common_file_info> fs_list_files(const std::string & path);

 struct common_sampler;

-// note: defines object's lifetime
+// note: defines the model, context, samplers, etc. lifetimes
 struct common_init_result {
     common_init_result(common_params & params);
     ~common_init_result();
@@ -163,84 +163,6 @@ struct common_sampler {
     mutable int64_t t_total_us = 0;
 };

-// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
-static bool common_sampler_type_has_backend_support(enum common_sampler_type type) {
-    switch (type) {
-        case COMMON_SAMPLER_TYPE_TOP_K:
-        case COMMON_SAMPLER_TYPE_TEMPERATURE:
-        case COMMON_SAMPLER_TYPE_MIN_P:
-        case COMMON_SAMPLER_TYPE_TOP_P:
-            return true;
-        default:
-            return false;
-    }
-}
-
-bool common_params_sampling::is_disabled(enum common_sampler_type type) const {
-    switch (type) {
-        case COMMON_SAMPLER_TYPE_PENALTIES:
-            if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_DRY:
-            if (dry_multiplier == 0.0f || dry_base < 1.0f || dry_penalty_last_n == 0) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_TYPICAL_P:
-            if (typ_p >= 1.0) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
-            if (top_n_sigma <= 0.0) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_TOP_K:
-            if (top_k <= 0) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_TEMPERATURE:
-            if (dynatemp_range <= 0.0f) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_MIN_P:
-            if (min_p <= 0.0f) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_TOP_P:
-            if (top_p >= 1.0f) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_XTC:
-            if (xtc_probability <= 0.0f || xtc_threshold > 0.50f) {
-                return true;
-            }
-            break;
-        default:
-            break;
-    }
-
-    return false;
-}
-
-void common_params_sampling::filter_disabled() {
-    for (auto it = samplers.begin(); it != samplers.end();) {
-        if (is_disabled(*it)) {
-            LOG_WRN("%s: removing disabled sampler %s\n", __func__, common_sampler_type_to_str(*it).c_str());
-            it = samplers.erase(it);
-        } else {
-            ++it;
-        }
-    }
-}
-
 std::string common_params_sampling::print() const {
     char result[1024];
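The per-type disable checks removed here do not disappear: each condition reappears later in this diff as an `is_empty` check inside the corresponding `llama_sampler_init_*` constructor. A minimal usage sketch of the simplified API (not part of this commit; `model` is a hypothetical, already-loaded model):

    #include "common.h"
    #include "sampling.h"

    void sampler_setup_example(const llama_model * model) {
        common_params_sampling sparams;  // library defaults for top_k, top_p, temp, ...
        sparams.backend_sampling = true; // request the backend chain where supported

        // with this commit the params can be passed as const: no filtering mutates them
        common_sampler * smpl = common_sampler_init(model, sparams);

        // ... decode and sample ...

        common_sampler_free(smpl);
    }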
@@ -257,7 +179,7 @@ std::string common_params_sampling::print() const {
     return std::string(result);
 }

-struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
     const llama_vocab * vocab = llama_model_get_vocab(model);

     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
@@ -324,11 +246,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
         }
     }

-    // TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
-    if (params.backend_sampling) {
-        params.filter_disabled();
-    }
-
     auto * result = new common_sampler {
         /* .params = */ params,
         /* .grmr   = */ grmr,
@@ -339,54 +256,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
         /* .cur_p  = */ {},
     };

-    size_t idx_smpl = 0;
-
-    bool is_backend = true;
-
-    is_backend = is_backend && params.backend_sampling;
-    is_backend = is_backend && (params.samplers.size() == 0 || common_sampler_type_has_backend_support(params.samplers[idx_smpl]));
-
+    std::vector<llama_sampler *> samplers;
     if (params.has_logit_bias()) {
-        llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain,
-                llama_sampler_init_logit_bias(
-                    llama_vocab_n_tokens(vocab),
-                    params.logit_bias.size(),
-                    params.logit_bias.data()));
+        samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
     }

     if (params.mirostat == 0) {
-        // backend samplers are added first
-        while (is_backend && idx_smpl < params.samplers.size()) {
-            const auto & cnstr = params.samplers[idx_smpl++];
-
-            if (!common_sampler_type_has_backend_support(cnstr)) {
-                is_backend = false;
-                --idx_smpl;
-                break;
-            }
-
-            switch (cnstr) {
-                case COMMON_SAMPLER_TYPE_TOP_K:
-                    llama_sampler_chain_add(result->chain_backend, llama_sampler_init_top_k(params.top_k));
-                    break;
-                case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                    llama_sampler_chain_add(result->chain_backend, llama_sampler_init_temp(params.temp));
-                    break;
-                case COMMON_SAMPLER_TYPE_MIN_P:
-                    llama_sampler_chain_add(result->chain_backend, llama_sampler_init_min_p(params.min_p, params.min_keep));
-                    break;
-                case COMMON_SAMPLER_TYPE_TOP_P:
-                    llama_sampler_chain_add(result->chain_backend, llama_sampler_init_top_p(params.top_p, params.min_keep));
-                    break;
-                default:
-                    GGML_ASSERT(false && "unsupported backend sampler");
-            }
-        }
-
-        // Add remaining CPU samplers
-        while (idx_smpl < params.samplers.size()) {
-            const auto & cnstr = params.samplers[idx_smpl++];
-
+        for (const auto & cnstr : params.samplers) {
             switch (cnstr) {
                 case COMMON_SAMPLER_TYPE_DRY:
                     {
@@ -396,52 +272,63 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
                             c_breakers.push_back(str.c_str());
                         }

-                        llama_sampler_chain_add(result->chain, llama_sampler_init_dry        (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                        samplers.push_back(llama_sampler_init_dry        (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                     }
                     break;
                 case COMMON_SAMPLER_TYPE_TOP_K:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_k      (params.top_k));
+                    samplers.push_back(llama_sampler_init_top_k      (params.top_k));
                     break;
                 case COMMON_SAMPLER_TYPE_TOP_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_p      (params.top_p, params.min_keep));
+                    samplers.push_back(llama_sampler_init_top_p      (params.top_p, params.min_keep));
                     break;
                 case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma(params.top_n_sigma));
+                    samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
                     break;
                 case COMMON_SAMPLER_TYPE_MIN_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_min_p      (params.min_p, params.min_keep));
+                    samplers.push_back(llama_sampler_init_min_p      (params.min_p, params.min_keep));
                     break;
                 case COMMON_SAMPLER_TYPE_XTC:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc        (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                    samplers.push_back(llama_sampler_init_xtc        (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                     break;
                 case COMMON_SAMPLER_TYPE_TYPICAL_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_typical    (params.typ_p, params.min_keep));
+                    samplers.push_back(llama_sampler_init_typical    (params.typ_p, params.min_keep));
                     break;
                 case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext   (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                    samplers.push_back(llama_sampler_init_temp_ext   (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                     break;
                 case COMMON_SAMPLER_TYPE_INFILL:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill     (vocab));
+                    samplers.push_back(llama_sampler_init_infill     (vocab));
                     break;
                 case COMMON_SAMPLER_TYPE_PENALTIES:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties  (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                    samplers.push_back(llama_sampler_init_penalties  (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
                     break;
                 default:
                     GGML_ASSERT(false && "unknown sampler type");
             }
         }

-        llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, llama_sampler_init_dist(params.seed));
+        samplers.push_back(llama_sampler_init_dist(params.seed));
     } else if (params.mirostat == 1) {
-        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+        samplers.push_back(llama_sampler_init_temp(params.temp));
+        samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
     } else if (params.mirostat == 2) {
-        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
+        samplers.push_back(llama_sampler_init_temp(params.temp));
+        samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
     } else {
         GGML_ASSERT(false && "unknown mirostat version");
     }

+    bool is_backend = params.backend_sampling;
+
+    // split in two chains: backend -> CPU
+    for (auto * smpl : samplers) {
+        if (!smpl->iface->backend_apply) {
+            is_backend = false;
+        }
+
+        llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, smpl);
+    }
+
     return result;
 }
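This loop is the heart of the refactor: the full sampler list is built first, then split into the backend chain and the CPU chain at the first sampler that lacks a `backend_apply` implementation. A standalone sketch of the same splitting rule, with simplified stand-in types (illustrative, not the commit's code):

    #include <vector>

    struct sampler_iface { void (*backend_apply)(void *); };
    struct sampler       { sampler_iface * iface; };

    struct chains {
        std::vector<sampler *> backend; // runs inside the ggml graph
        std::vector<sampler *> cpu;     // runs on the host, after the graph
    };

    chains split(const std::vector<sampler *> & samplers, bool backend_sampling) {
        chains res;
        bool is_backend = backend_sampling;
        for (auto * smpl : samplers) {
            if (!smpl->iface->backend_apply) {
                is_backend = false; // once false, it stays false: sampler order is preserved
            }
            (is_backend ? res.backend : res.cpu).push_back(smpl);
        }
        return res;
    }

Because `is_backend` never flips back to true, the backend chain is always a prefix of the configured sampler order, so splitting cannot reorder the samplers.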
@@ -36,8 +36,7 @@ struct common_sampler;

 // llama_sampler API overloads

-// TODO: params should become const again [LLAMA_SAMPLER_BACKEND]
-struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params);
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);

 void common_sampler_free(struct common_sampler * gsmpl);
@@ -73,16 +73,10 @@ int main(int argc, char ** argv) {
     for (int32_t i = 0; i < n_parallel; ++i) {
         llama_sampler * smpl = llama_sampler_chain_init(sparams);

-        if (params.sampling.backend_sampling) {
-            llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
-            llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
-            llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
-        } else {
-            llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
-            llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
-            llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
-            llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
-        }
+        llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
+        llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
+        llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
+        llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));

         sampler_configs.push_back({ i, smpl });
     }
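With the `backend_sampling` branch gone, the example builds the same top-k, top-p, temp, dist chain unconditionally for every parallel sequence. A sketch of how such a per-sequence chain is typically driven through the public API (illustrative; `lctx` and `i_batch` are assumed to come from a successful llama_decode call):

    #include "llama.h"

    llama_token sample_one(llama_context * lctx, int32_t i_batch) {
        llama_sampler_chain_params sparams = llama_sampler_chain_default_params();
        llama_sampler * smpl = llama_sampler_chain_init(sparams);

        llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));
        llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.95f, 1));
        llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.80f));
        llama_sampler_chain_add(smpl, llama_sampler_init_dist (LLAMA_DEFAULT_SEED));

        // samples from the logits at batch position i_batch, applying the whole chain
        const llama_token tok = llama_sampler_sample(smpl, lctx, i_batch);

        llama_sampler_free(smpl);
        return tok;
    }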
@@ -1212,7 +1212,7 @@ extern "C" {
     };

     struct llama_sampler {
-        const struct llama_sampler_i * iface;
+        struct llama_sampler_i * iface;

         llama_sampler_context_t ctx;
     };
@@ -1220,7 +1220,7 @@
     LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);

     // mirror of llama_sampler_i:
-    LLAMA_API struct llama_sampler * llama_sampler_init  (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
+    LLAMA_API struct llama_sampler * llama_sampler_init  (      struct llama_sampler_i * iface, llama_sampler_context_t ctx);
     LLAMA_API const char *           llama_sampler_name  (const struct llama_sampler * smpl);
     LLAMA_API void                   llama_sampler_accept(      struct llama_sampler * smpl, llama_token token);
     LLAMA_API void                   llama_sampler_apply (      struct llama_sampler * smpl, llama_token_data_array * cur_p);
@@ -2102,7 +2102,7 @@ void llm_graph_context::build_sampling() const {
             ggml_build_forward_expand(gf, data.sampled);
         }

-        if (data.probs != nullptr) {
+        if (data.probs != nullptr) {
             res->t_sampled_probs[seq_id] = data.probs;
             ggml_build_forward_expand(gf, data.probs);
         }
@@ -349,7 +349,7 @@ static uint32_t get_rng_seed(uint32_t seed) {
 // llama_sampler API

 struct llama_sampler * llama_sampler_init(
-        const struct llama_sampler_i * iface,
+              struct llama_sampler_i * iface,
         llama_sampler_context_t ctx) {
     return new llama_sampler {
         /* .iface = */ iface,
@@ -468,6 +468,42 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
     return token;
 }

+// backend sampling (empty iface)
+
+static void llama_sampler_empty_backend_init(
+        struct llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    GGML_UNUSED(smpl);
+    GGML_UNUSED(buft);
+}
+
+static void llama_sampler_empty_backend_accept(
+        struct llama_sampler * smpl,
+        ggml_context * ctx,
+        ggml_cgraph * gf,
+        struct ggml_tensor * selected_token) {
+    GGML_UNUSED(smpl);
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(gf);
+    GGML_UNUSED(selected_token);
+}
+
+static void llama_sampler_empty_backend_apply(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_data * data) {
+    GGML_UNUSED(smpl);
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(gf);
+    GGML_UNUSED(data);
+}
+
+static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) {
+    GGML_UNUSED(smpl);
+}
+

 // sampler chain

 static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
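These no-op hooks are why `iface` loses its `const` elsewhere in this commit: a constructor can patch them into a sampler whose parameters make it a no-op, so the chain-splitting loop in `common_sampler_init` still classifies it as backend-capable. The pattern, as a sketch (the helper name `init_noop_aware` is hypothetical, not part of the commit):

    static struct llama_sampler * init_noop_aware(
            struct llama_sampler_i * iface,
            llama_sampler_context_t  ctx,
            bool                     is_empty) {
        auto * res = llama_sampler_init(iface, ctx);

        if (is_empty) {
            // a no-op sampler advertises backend support via do-nothing hooks
            res->iface->backend_init      = llama_sampler_empty_backend_init;
            res->iface->backend_accept    = llama_sampler_empty_backend_accept;
            res->iface->backend_apply     = llama_sampler_empty_backend_apply;
            res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
        }

        return res;
    }

One consequence worth noting: `res->iface` points at the sampler's shared static `*_i` table, so the patch affects every instance created with such parameters, which is presumably why the `const` had to go.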
@@ -1171,7 +1207,7 @@ static void llama_sampler_top_p_backend_apply(

     ggml_set_output(data->candidates);
     ggml_build_forward_expand(gf, data->candidates);

     ggml_set_output(data->logits);
     ggml_build_forward_expand(gf, data->logits);
 }
@@ -1446,13 +1482,24 @@ static struct llama_sampler_i llama_sampler_typical_i = {
 };

 struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
-    return llama_sampler_init(
+    auto * res = llama_sampler_init(
         /* .iface = */ &llama_sampler_typical_i,
         /* .ctx   = */ new llama_sampler_typical {
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
         }
     );
+
+    const bool is_empty = (p >= 1.0f);
+
+    if (is_empty) {
+        res->iface->backend_init      = llama_sampler_empty_backend_init;
+        res->iface->backend_accept    = llama_sampler_empty_backend_accept;
+        res->iface->backend_apply     = llama_sampler_empty_backend_apply;
+        res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
+    }
+
+    return res;
 }

 // temp
@@ -1615,6 +1662,27 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
     delete (llama_sampler_temp_ext *) smpl->ctx;
 }

+static void llama_sampler_temp_ext_backend_apply(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_data * data) {
+    auto * ctx_data = (llama_sampler_temp_ext *) smpl->ctx;
+
+    if (ctx_data->temp <= 0.0f) {
+        return;
+    }
+
+    struct ggml_tensor * scaled = ggml_scale(ctx, data->logits, 1.0f / ctx_data->temp);
+    ggml_set_name(scaled, "temp_scaled");
+
+    // Make sure the scaled tensor is contiguous for subsequent operations
+    data->logits = ggml_cont(ctx, scaled);
+    ggml_set_name(data->logits, "temp_scaled_logits");
+
+    ggml_build_forward_expand(gf, data->logits);
+}
+
 static struct llama_sampler_i llama_sampler_temp_ext_i = {
     /* .name   = */ llama_sampler_temp_ext_name,
     /* .accept = */ nullptr,
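In plain CPU terms, the graph nodes added above compute `logits / temp`. A small equivalent sketch (illustrative only, not part of the commit):

    #include <vector>

    void apply_temp(std::vector<float> & logits, float temp) {
        if (temp <= 0.0f) {
            return; // mirrors the early-out in the backend version
        }
        for (float & l : logits) {
            l *= 1.0f / temp; // same scaling as ggml_scale(ctx, logits, 1.0f / temp)
        }
    }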
@@ -1629,7 +1697,7 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
 };

 struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
-    return llama_sampler_init(
+    auto * res = llama_sampler_init(
         /* .iface = */ &llama_sampler_temp_ext_i,
         /* .ctx   = */ new llama_sampler_temp_ext {
             /* .temp = */ temp,
@@ -1637,6 +1705,14 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
             /* .exponent = */ exponent,
         }
     );
+
+    const bool is_backend = delta <= 0.0f;
+
+    if (is_backend) {
+        res->iface->backend_apply = llama_sampler_temp_ext_backend_apply;
+    }
+
+    return res;
 }

 // xtc
@@ -1727,8 +1803,9 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
 };

 struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
-    auto seed_cur = get_rng_seed(seed);
-    return llama_sampler_init(
+    const auto seed_cur = get_rng_seed(seed);
+
+    auto * res = llama_sampler_init(
         /* .iface = */ &llama_sampler_xtc_i,
         /* .ctx   = */ new llama_sampler_xtc {
             /* .probability = */ p,
@@ -1739,6 +1816,17 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
             /* .rng         = */ std::mt19937(seed_cur),
         }
     );
+
+    const bool is_empty = (p <= 0.0f || t > 0.5f);
+
+    if (is_empty) {
+        res->iface->backend_init      = llama_sampler_empty_backend_init;
+        res->iface->backend_accept    = llama_sampler_empty_backend_accept;
+        res->iface->backend_apply     = llama_sampler_empty_backend_apply;
+        res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
+    }
+
+    return res;
 }

 // mirostat
@@ -2280,7 +2368,7 @@ struct llama_sampler * llama_sampler_init_penalties(
         float penalty_present) {
     penalty_last_n = std::max(penalty_last_n, 0);

-    return llama_sampler_init(
+    auto * res = llama_sampler_init(
         /* .iface = */ &llama_sampler_penalties_i,
         /* .ctx   = */ new llama_sampler_penalties {
             /* .penalty_last_n = */ penalty_last_n,
@@ -2291,6 +2379,17 @@ struct llama_sampler * llama_sampler_init_penalties(
             /* .token_count = */ {},
         }
     );
+
+    const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
+
+    if (is_empty) {
+        res->iface->backend_init      = llama_sampler_empty_backend_init;
+        res->iface->backend_accept    = llama_sampler_empty_backend_accept;
+        res->iface->backend_apply     = llama_sampler_empty_backend_apply;
+        res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
+    }
+
+    return res;
 }

 // top-n-sigma
@@ -2317,9 +2416,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     for (size_t i = 0; i < cur_p->size; ++i) {
         // Only count non-negative infinity values
         if (cur_p->data[i].logit != -INFINITY) {
-            if (cur_p->data[i].logit > max) {
-                max = cur_p->data[i].logit;
-            }
+            max = std::max(max, cur_p->data[i].logit);
             logits_sum += cur_p->data[i].logit;
             valid_count++;
         }
@@ -2369,12 +2466,23 @@ static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
 };

 struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
-    return llama_sampler_init(
+    auto * res = llama_sampler_init(
         /* .iface = */ &llama_sampler_top_n_sigma_i,
         /* .ctx   = */ new llama_sampler_top_n_sigma {
             /* .n = */ n,
         }
     );
+
+    const bool is_empty = (n <= 0.0f);
+
+    if (is_empty) {
+        res->iface->backend_init      = llama_sampler_empty_backend_init;
+        res->iface->backend_accept    = llama_sampler_empty_backend_accept;
+        res->iface->backend_apply     = llama_sampler_empty_backend_apply;
+        res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
+    }
+
+    return res;
 }

 // DRY
@@ -2733,7 +2841,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
         }
     }

-    return llama_sampler_init(
+    auto * res = llama_sampler_init(
         /* .iface = */ &llama_sampler_dry_i,
         /* .ctx   = */ new llama_sampler_dry {
             /* .total_context_size = */ n_ctx_train,
@@ -2747,6 +2855,15 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
             /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
         }
     );
+
+    if (!dry_enabled) {
+        res->iface->backend_init      = llama_sampler_empty_backend_init;
+        res->iface->backend_accept    = llama_sampler_empty_backend_accept;
+        res->iface->backend_apply     = llama_sampler_empty_backend_apply;
+        res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
+    }
+
+    return res;
 }

 // wrapper for test-sampling.cpp
@@ -2854,6 +2971,8 @@ static void llama_sampler_logit_bias_backend_apply(

     // Add the sparse logit_bias to the logits
     struct ggml_tensor * logit_biased = ggml_add_inplace(ctx, data->logits, sctx->inp_logit_bias);
+    data->logits = logit_biased;
+
     ggml_build_forward_expand(gf, logit_biased);
 }
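The backend variant applies logit biases as a single in-graph `ggml_add_inplace` of a dense, vocab-sized bias tensor (`inp_logit_bias`, filled at set-input time). A CPU-side sketch of building that dense tensor from the sparse (token, bias) pairs (illustrative; `logit_bias` mirrors `llama_logit_bias`):

    #include <cstdint>
    #include <vector>

    struct logit_bias { int32_t token; float bias; };

    std::vector<float> make_dense_bias(const std::vector<logit_bias> & sparse, int32_t n_vocab) {
        std::vector<float> dense(n_vocab, 0.0f); // zero everywhere except biased tokens
        for (const auto & lb : sparse) {
            dense[lb.token] += lb.bias;
        }
        return dense; // adding this to the logits applies every bias in one tensor op
    }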