add power law case to `common_sampler_init`, add sampler name mappings
This commit is contained in:
parent
b3aea57768
commit
cd7de7c7a8
|
|
@ -243,6 +243,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
params.logit_bias.data()));
|
||||
|
||||
if (params.mirostat == 0) {
|
||||
// if this flag is set, we will not need to add `dist` at the end of the sampler chain
|
||||
bool has_distribution_sampler = false;
|
||||
|
||||
for (const auto & cnstr : params.samplers) {
|
||||
switch (cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_DRY:
|
||||
|
|
@ -253,7 +256,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
c_breakers.push_back(str.c_str());
|
||||
}
|
||||
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
|
|
@ -283,11 +286,18 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_POWER_LAW:
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_power_law (params.power_law_target, params.power_law_target_range, params.power_law_window_size, params.seed));
|
||||
has_distribution_sampler = true;
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unknown sampler type");
|
||||
}
|
||||
}
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
||||
// only add `dist` to the end of the chain if no other distribution samplers were added
|
||||
if (!has_distribution_sampler) {
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
||||
}
|
||||
} else if (params.mirostat == 1) {
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
|
||||
|
|
@ -586,6 +596,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
|||
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
|
||||
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
|
||||
{ "penalties", COMMON_SAMPLER_TYPE_PENALTIES },
|
||||
{ "power_law", COMMON_SAMPLER_TYPE_POWER_LAW },
|
||||
};
|
||||
|
||||
// since sampler names are written in multiple ways
|
||||
|
|
@ -601,6 +612,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
|
|||
{ "typ", COMMON_SAMPLER_TYPE_TYPICAL_P },
|
||||
{ "min-p", COMMON_SAMPLER_TYPE_MIN_P },
|
||||
{ "temp", COMMON_SAMPLER_TYPE_TEMPERATURE },
|
||||
{ "power-law", COMMON_SAMPLER_TYPE_POWER_LAW },
|
||||
};
|
||||
|
||||
std::vector<common_sampler_type> samplers;
|
||||
|
|
|
|||
Loading…
Reference in New Issue