common : simplify sampler chain initialization
This commit is contained in:
parent
217469f07f
commit
4032ce2378
|
|
@ -212,18 +212,12 @@ struct common_params_sampling {
|
||||||
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
std::vector<llama_logit_bias> logit_bias; // logit biases to apply
|
||||||
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
|
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
|
||||||
|
|
||||||
bool backend_sampling = false; // enable backend sampling
|
bool backend_sampling = false;
|
||||||
|
|
||||||
bool has_logit_bias() const {
|
bool has_logit_bias() const {
|
||||||
return !logit_bias.empty();
|
return !logit_bias.empty();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_disabled(enum common_sampler_type type) const;
|
|
||||||
|
|
||||||
// remove disabled samplers
|
|
||||||
// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
|
|
||||||
void filter_disabled();
|
|
||||||
|
|
||||||
// print the parameters into a string
|
// print the parameters into a string
|
||||||
std::string print() const;
|
std::string print() const;
|
||||||
};
|
};
|
||||||
|
|
@ -661,7 +655,7 @@ std::vector<common_file_info> fs_list_files(const std::string & path);
|
||||||
|
|
||||||
struct common_sampler;
|
struct common_sampler;
|
||||||
|
|
||||||
// note: defines object's lifetime
|
// note: defines the model, context, samplers, ets. lifetimes
|
||||||
struct common_init_result {
|
struct common_init_result {
|
||||||
common_init_result(common_params & params);
|
common_init_result(common_params & params);
|
||||||
~common_init_result();
|
~common_init_result();
|
||||||
|
|
|
||||||
|
|
@ -163,84 +163,6 @@ struct common_sampler {
|
||||||
mutable int64_t t_total_us = 0;
|
mutable int64_t t_total_us = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
|
|
||||||
static bool common_sampler_type_has_backend_support(enum common_sampler_type type) {
|
|
||||||
switch (type) {
|
|
||||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
|
||||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
|
||||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
|
||||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
|
||||||
return true;
|
|
||||||
default:
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bool common_params_sampling::is_disabled(enum common_sampler_type type) const {
|
|
||||||
switch (type) {
|
|
||||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
|
||||||
if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_DRY:
|
|
||||||
if (dry_multiplier == 0.0f || dry_base < 1.0f || dry_penalty_last_n == 0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
|
||||||
if (typ_p >= 1.0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
|
|
||||||
if (top_n_sigma <= 0.0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
|
||||||
if (top_k <= 0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
|
||||||
if (dynatemp_range <= 0.0f) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
|
||||||
if (min_p <= 0.0f) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
|
||||||
if (top_p >= 1.0f) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_XTC:
|
|
||||||
if (xtc_probability <= 0.0f || xtc_threshold == 0.50f) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
void common_params_sampling::filter_disabled() {
|
|
||||||
for (auto it = samplers.begin(); it != samplers.end();) {
|
|
||||||
if (is_disabled(*it)) {
|
|
||||||
LOG_WRN("%s: removing disabled sampler %s\n", __func__, common_sampler_type_to_str(*it).c_str());
|
|
||||||
it = samplers.erase(it);
|
|
||||||
} else {
|
|
||||||
++it;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string common_params_sampling::print() const {
|
std::string common_params_sampling::print() const {
|
||||||
char result[1024];
|
char result[1024];
|
||||||
|
|
||||||
|
|
@ -257,7 +179,7 @@ std::string common_params_sampling::print() const {
|
||||||
return std::string(result);
|
return std::string(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
|
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
|
||||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||||
|
|
||||||
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
|
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
|
||||||
|
|
@ -324,11 +246,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
|
|
||||||
if (params.backend_sampling) {
|
|
||||||
params.filter_disabled();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto * result = new common_sampler {
|
auto * result = new common_sampler {
|
||||||
/* .params = */ params,
|
/* .params = */ params,
|
||||||
/* .grmr = */ grmr,
|
/* .grmr = */ grmr,
|
||||||
|
|
@ -339,54 +256,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||||
/* .cur_p = */ {},
|
/* .cur_p = */ {},
|
||||||
};
|
};
|
||||||
|
|
||||||
size_t idx_smpl = 0;
|
std::vector<llama_sampler *> samplers;
|
||||||
|
|
||||||
bool is_backend = true;
|
|
||||||
|
|
||||||
is_backend = is_backend && params.backend_sampling;
|
|
||||||
is_backend = is_backend && (params.samplers.size() == 0 || common_sampler_type_has_backend_support(params.samplers[idx_smpl]));
|
|
||||||
|
|
||||||
if (params.has_logit_bias()) {
|
if (params.has_logit_bias()) {
|
||||||
llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain,
|
samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
|
||||||
llama_sampler_init_logit_bias(
|
|
||||||
llama_vocab_n_tokens(vocab),
|
|
||||||
params.logit_bias.size(),
|
|
||||||
params.logit_bias.data()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.mirostat == 0) {
|
if (params.mirostat == 0) {
|
||||||
// backend samplers are added first
|
for (const auto & cnstr : params.samplers) {
|
||||||
while (is_backend && idx_smpl < params.samplers.size()) {
|
|
||||||
const auto & cnstr = params.samplers[idx_smpl++];
|
|
||||||
|
|
||||||
if (!common_sampler_type_has_backend_support(cnstr)) {
|
|
||||||
is_backend = false;
|
|
||||||
--idx_smpl;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (cnstr) {
|
|
||||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
|
||||||
llama_sampler_chain_add(result->chain_backend, llama_sampler_init_top_k(params.top_k));
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
|
||||||
llama_sampler_chain_add(result->chain_backend, llama_sampler_init_temp(params.temp));
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
|
||||||
llama_sampler_chain_add(result->chain_backend, llama_sampler_init_min_p(params.min_p, params.min_keep));
|
|
||||||
break;
|
|
||||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
|
||||||
llama_sampler_chain_add(result->chain_backend, llama_sampler_init_top_p(params.top_p, params.min_keep));
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
GGML_ASSERT(false && "unsupported backend sampler");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add remaining CPU samplers
|
|
||||||
while (idx_smpl < params.samplers.size()) {
|
|
||||||
const auto & cnstr = params.samplers[idx_smpl++];
|
|
||||||
|
|
||||||
switch (cnstr) {
|
switch (cnstr) {
|
||||||
case COMMON_SAMPLER_TYPE_DRY:
|
case COMMON_SAMPLER_TYPE_DRY:
|
||||||
{
|
{
|
||||||
|
|
@ -396,52 +272,63 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
||||||
c_breakers.push_back(str.c_str());
|
c_breakers.push_back(str.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k));
|
samplers.push_back(llama_sampler_init_top_k (params.top_k));
|
||||||
break;
|
break;
|
||||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
case COMMON_SAMPLER_TYPE_TOP_P:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep));
|
samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
|
||||||
break;
|
break;
|
||||||
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
|
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
|
samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
|
||||||
break;
|
break;
|
||||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep));
|
samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
|
||||||
break;
|
break;
|
||||||
case COMMON_SAMPLER_TYPE_XTC:
|
case COMMON_SAMPLER_TYPE_XTC:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
|
||||||
break;
|
break;
|
||||||
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep));
|
samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
|
||||||
break;
|
break;
|
||||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
|
||||||
break;
|
break;
|
||||||
case COMMON_SAMPLER_TYPE_INFILL:
|
case COMMON_SAMPLER_TYPE_INFILL:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
|
samplers.push_back(llama_sampler_init_infill (vocab));
|
||||||
break;
|
break;
|
||||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false && "unknown sampler type");
|
GGML_ASSERT(false && "unknown sampler type");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, llama_sampler_init_dist(params.seed));
|
samplers.push_back(llama_sampler_init_dist(params.seed));
|
||||||
} else if (params.mirostat == 1) {
|
} else if (params.mirostat == 1) {
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
samplers.push_back(llama_sampler_init_temp(params.temp));
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
|
samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
|
||||||
} else if (params.mirostat == 2) {
|
} else if (params.mirostat == 2) {
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
samplers.push_back(llama_sampler_init_temp(params.temp));
|
||||||
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
|
samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
|
||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(false && "unknown mirostat version");
|
GGML_ASSERT(false && "unknown mirostat version");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool is_backend = params.backend_sampling;
|
||||||
|
|
||||||
|
// split in two chains: backend -> CPU
|
||||||
|
for (auto * smpl : samplers) {
|
||||||
|
if (!smpl->iface->backend_apply) {
|
||||||
|
is_backend = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, smpl);
|
||||||
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -36,8 +36,7 @@ struct common_sampler;
|
||||||
|
|
||||||
// llama_sampler API overloads
|
// llama_sampler API overloads
|
||||||
|
|
||||||
// TODO: params should become const again [LLAMA_SAMPLER_BACKEND]
|
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
|
||||||
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params);
|
|
||||||
|
|
||||||
void common_sampler_free(struct common_sampler * gsmpl);
|
void common_sampler_free(struct common_sampler * gsmpl);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -73,16 +73,10 @@ int main(int argc, char ** argv) {
|
||||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||||
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||||
|
|
||||||
if (params.sampling.backend_sampling) {
|
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
|
||||||
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
|
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
|
||||||
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
|
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
|
||||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
|
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
|
||||||
} else {
|
|
||||||
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
|
|
||||||
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
|
|
||||||
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
|
|
||||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
|
|
||||||
}
|
|
||||||
|
|
||||||
sampler_configs.push_back({ i, smpl });
|
sampler_configs.push_back({ i, smpl });
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1212,7 +1212,7 @@ extern "C" {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler {
|
struct llama_sampler {
|
||||||
const struct llama_sampler_i * iface;
|
struct llama_sampler_i * iface;
|
||||||
|
|
||||||
llama_sampler_context_t ctx;
|
llama_sampler_context_t ctx;
|
||||||
};
|
};
|
||||||
|
|
@ -1220,7 +1220,7 @@ extern "C" {
|
||||||
LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
|
LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
|
||||||
|
|
||||||
// mirror of llama_sampler_i:
|
// mirror of llama_sampler_i:
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
|
LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_i * iface, llama_sampler_context_t ctx);
|
||||||
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
|
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
|
||||||
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
|
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
|
||||||
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
|
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
|
||||||
|
|
|
||||||
|
|
@ -2102,7 +2102,7 @@ void llm_graph_context::build_sampling() const {
|
||||||
ggml_build_forward_expand(gf, data.sampled);
|
ggml_build_forward_expand(gf, data.sampled);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (data.probs != nullptr) {
|
if (data.probs != nullptr) {
|
||||||
res->t_sampled_probs[seq_id] = data.probs;
|
res->t_sampled_probs[seq_id] = data.probs;
|
||||||
ggml_build_forward_expand(gf, data.probs);
|
ggml_build_forward_expand(gf, data.probs);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -349,7 +349,7 @@ static uint32_t get_rng_seed(uint32_t seed) {
|
||||||
// llama_sampler API
|
// llama_sampler API
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init(
|
struct llama_sampler * llama_sampler_init(
|
||||||
const struct llama_sampler_i * iface,
|
struct llama_sampler_i * iface,
|
||||||
llama_sampler_context_t ctx) {
|
llama_sampler_context_t ctx) {
|
||||||
return new llama_sampler {
|
return new llama_sampler {
|
||||||
/* .iface = */ iface,
|
/* .iface = */ iface,
|
||||||
|
|
@ -468,6 +468,42 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
|
||||||
return token;
|
return token;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// backend sampling (empty iface)
|
||||||
|
|
||||||
|
static void llama_sampler_empty_backend_init(
|
||||||
|
struct llama_sampler * smpl,
|
||||||
|
ggml_backend_buffer_type_t buft) {
|
||||||
|
GGML_UNUSED(smpl);
|
||||||
|
GGML_UNUSED(buft);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void llama_sampler_empty_backend_accept(
|
||||||
|
struct llama_sampler * smpl,
|
||||||
|
ggml_context * ctx,
|
||||||
|
ggml_cgraph * gf,
|
||||||
|
struct ggml_tensor * selected_token) {
|
||||||
|
GGML_UNUSED(smpl);
|
||||||
|
GGML_UNUSED(ctx);
|
||||||
|
GGML_UNUSED(gf);
|
||||||
|
GGML_UNUSED(selected_token);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void llama_sampler_empty_backend_apply(
|
||||||
|
struct llama_sampler * smpl,
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_cgraph * gf,
|
||||||
|
struct llama_sampler_data * data) {
|
||||||
|
GGML_UNUSED(smpl);
|
||||||
|
GGML_UNUSED(ctx);
|
||||||
|
GGML_UNUSED(gf);
|
||||||
|
GGML_UNUSED(data);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) {
|
||||||
|
GGML_UNUSED(smpl);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
// sampler chain
|
// sampler chain
|
||||||
|
|
||||||
static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
|
static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
|
||||||
|
|
@ -1171,7 +1207,7 @@ static void llama_sampler_top_p_backend_apply(
|
||||||
|
|
||||||
ggml_set_output(data->candidates);
|
ggml_set_output(data->candidates);
|
||||||
ggml_build_forward_expand(gf, data->candidates);
|
ggml_build_forward_expand(gf, data->candidates);
|
||||||
|
|
||||||
ggml_set_output(data->logits);
|
ggml_set_output(data->logits);
|
||||||
ggml_build_forward_expand(gf, data->logits);
|
ggml_build_forward_expand(gf, data->logits);
|
||||||
}
|
}
|
||||||
|
|
@ -1446,13 +1482,24 @@ static struct llama_sampler_i llama_sampler_typical_i = {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
|
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
|
||||||
return llama_sampler_init(
|
auto * res = llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_typical_i,
|
/* .iface = */ &llama_sampler_typical_i,
|
||||||
/* .ctx = */ new llama_sampler_typical {
|
/* .ctx = */ new llama_sampler_typical {
|
||||||
/* .p = */ p,
|
/* .p = */ p,
|
||||||
/* .min_keep = */ min_keep,
|
/* .min_keep = */ min_keep,
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const bool is_empty = (p >= 1.0f);
|
||||||
|
|
||||||
|
if (is_empty) {
|
||||||
|
res->iface->backend_init = llama_sampler_empty_backend_init;
|
||||||
|
res->iface->backend_accept = llama_sampler_empty_backend_accept;
|
||||||
|
res->iface->backend_apply = llama_sampler_empty_backend_apply;
|
||||||
|
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
// temp
|
// temp
|
||||||
|
|
@ -1615,6 +1662,27 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
|
||||||
delete (llama_sampler_temp_ext *) smpl->ctx;
|
delete (llama_sampler_temp_ext *) smpl->ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void llama_sampler_temp_ext_backend_apply(
|
||||||
|
struct llama_sampler * smpl,
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_cgraph * gf,
|
||||||
|
struct llama_sampler_data * data) {
|
||||||
|
auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
|
||||||
|
|
||||||
|
if (ctx_data->temp <= 0.0f) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * scaled = ggml_scale(ctx, data->logits, 1.0f / ctx_data->temp);
|
||||||
|
ggml_set_name(scaled, "temp_scaled");
|
||||||
|
|
||||||
|
// Make sure the scaled tensor is contiguous for subsequent operations
|
||||||
|
data->logits = ggml_cont(ctx, scaled);
|
||||||
|
ggml_set_name(data->logits, "temp_scaled_logits");
|
||||||
|
|
||||||
|
ggml_build_forward_expand(gf, data->logits);
|
||||||
|
}
|
||||||
|
|
||||||
static struct llama_sampler_i llama_sampler_temp_ext_i = {
|
static struct llama_sampler_i llama_sampler_temp_ext_i = {
|
||||||
/* .name = */ llama_sampler_temp_ext_name,
|
/* .name = */ llama_sampler_temp_ext_name,
|
||||||
/* .accept = */ nullptr,
|
/* .accept = */ nullptr,
|
||||||
|
|
@ -1629,7 +1697,7 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
|
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
|
||||||
return llama_sampler_init(
|
auto * res = llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_temp_ext_i,
|
/* .iface = */ &llama_sampler_temp_ext_i,
|
||||||
/* .ctx = */ new llama_sampler_temp_ext {
|
/* .ctx = */ new llama_sampler_temp_ext {
|
||||||
/* .temp = */ temp,
|
/* .temp = */ temp,
|
||||||
|
|
@ -1637,6 +1705,14 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
|
||||||
/* .exponent = */ exponent,
|
/* .exponent = */ exponent,
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const bool is_backend = delta <= 0.0f;
|
||||||
|
|
||||||
|
if (is_backend) {
|
||||||
|
res->iface->backend_apply = llama_sampler_temp_ext_backend_apply;
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
// xtc
|
// xtc
|
||||||
|
|
@ -1727,8 +1803,9 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
|
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
|
||||||
auto seed_cur = get_rng_seed(seed);
|
const auto seed_cur = get_rng_seed(seed);
|
||||||
return llama_sampler_init(
|
|
||||||
|
auto * res = llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_xtc_i,
|
/* .iface = */ &llama_sampler_xtc_i,
|
||||||
/* .ctx = */ new llama_sampler_xtc {
|
/* .ctx = */ new llama_sampler_xtc {
|
||||||
/* .probability = */ p,
|
/* .probability = */ p,
|
||||||
|
|
@ -1739,6 +1816,17 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
|
||||||
/* .rng = */ std::mt19937(seed_cur),
|
/* .rng = */ std::mt19937(seed_cur),
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const bool is_empty = (p <= 0.0f || t > 0.5f);
|
||||||
|
|
||||||
|
if (is_empty) {
|
||||||
|
res->iface->backend_init = llama_sampler_empty_backend_init;
|
||||||
|
res->iface->backend_accept = llama_sampler_empty_backend_accept;
|
||||||
|
res->iface->backend_apply = llama_sampler_empty_backend_apply;
|
||||||
|
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
// mirostat
|
// mirostat
|
||||||
|
|
@ -2280,7 +2368,7 @@ struct llama_sampler * llama_sampler_init_penalties(
|
||||||
float penalty_present) {
|
float penalty_present) {
|
||||||
penalty_last_n = std::max(penalty_last_n, 0);
|
penalty_last_n = std::max(penalty_last_n, 0);
|
||||||
|
|
||||||
return llama_sampler_init(
|
auto * res = llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_penalties_i,
|
/* .iface = */ &llama_sampler_penalties_i,
|
||||||
/* .ctx = */ new llama_sampler_penalties {
|
/* .ctx = */ new llama_sampler_penalties {
|
||||||
/* .penalty_last_n = */ penalty_last_n,
|
/* .penalty_last_n = */ penalty_last_n,
|
||||||
|
|
@ -2291,6 +2379,17 @@ struct llama_sampler * llama_sampler_init_penalties(
|
||||||
/* .token_count = */ {},
|
/* .token_count = */ {},
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
|
||||||
|
|
||||||
|
if (is_empty) {
|
||||||
|
res->iface->backend_init = llama_sampler_empty_backend_init;
|
||||||
|
res->iface->backend_accept = llama_sampler_empty_backend_accept;
|
||||||
|
res->iface->backend_apply = llama_sampler_empty_backend_apply;
|
||||||
|
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
// top-n-sigma
|
// top-n-sigma
|
||||||
|
|
@ -2317,9 +2416,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
|
||||||
for (size_t i = 0; i < cur_p->size; ++i) {
|
for (size_t i = 0; i < cur_p->size; ++i) {
|
||||||
// Only count non-negative infinity values
|
// Only count non-negative infinity values
|
||||||
if (cur_p->data[i].logit != -INFINITY) {
|
if (cur_p->data[i].logit != -INFINITY) {
|
||||||
if (cur_p->data[i].logit > max) {
|
max = std::max(max, cur_p->data[i].logit);
|
||||||
max = cur_p->data[i].logit;
|
|
||||||
}
|
|
||||||
logits_sum += cur_p->data[i].logit;
|
logits_sum += cur_p->data[i].logit;
|
||||||
valid_count++;
|
valid_count++;
|
||||||
}
|
}
|
||||||
|
|
@ -2369,12 +2466,23 @@ static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
|
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
|
||||||
return llama_sampler_init(
|
auto * res = llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_top_n_sigma_i,
|
/* .iface = */ &llama_sampler_top_n_sigma_i,
|
||||||
/* .ctx = */ new llama_sampler_top_n_sigma {
|
/* .ctx = */ new llama_sampler_top_n_sigma {
|
||||||
/* .n = */ n,
|
/* .n = */ n,
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const bool is_empty = (n <= 0.0f);
|
||||||
|
|
||||||
|
if (is_empty) {
|
||||||
|
res->iface->backend_init = llama_sampler_empty_backend_init;
|
||||||
|
res->iface->backend_accept = llama_sampler_empty_backend_accept;
|
||||||
|
res->iface->backend_apply = llama_sampler_empty_backend_apply;
|
||||||
|
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
// DRY
|
// DRY
|
||||||
|
|
@ -2733,7 +2841,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return llama_sampler_init(
|
auto * res = llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_dry_i,
|
/* .iface = */ &llama_sampler_dry_i,
|
||||||
/* .ctx = */ new llama_sampler_dry {
|
/* .ctx = */ new llama_sampler_dry {
|
||||||
/* .total_context_size = */ n_ctx_train,
|
/* .total_context_size = */ n_ctx_train,
|
||||||
|
|
@ -2747,6 +2855,15 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
|
||||||
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
|
/* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
if (!dry_enabled) {
|
||||||
|
res->iface->backend_init = llama_sampler_empty_backend_init;
|
||||||
|
res->iface->backend_accept = llama_sampler_empty_backend_accept;
|
||||||
|
res->iface->backend_apply = llama_sampler_empty_backend_apply;
|
||||||
|
res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
// wrapper for test-sampling.cpp
|
// wrapper for test-sampling.cpp
|
||||||
|
|
@ -2854,6 +2971,8 @@ static void llama_sampler_logit_bias_backend_apply(
|
||||||
|
|
||||||
// Add the sparse logit logit_bias to the logits
|
// Add the sparse logit logit_bias to the logits
|
||||||
struct ggml_tensor * logit_biased = ggml_add_inplace(ctx, data->logits, sctx->inp_logit_bias);
|
struct ggml_tensor * logit_biased = ggml_add_inplace(ctx, data->logits, sctx->inp_logit_bias);
|
||||||
|
data->logits = logit_biased;
|
||||||
|
|
||||||
ggml_build_forward_expand(gf, logit_biased);
|
ggml_build_forward_expand(gf, logit_biased);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue