common : add get_active_samplers function to check enabled samplers
This commit adds a function that checks whether a sampler is actually enabled, that is, whether its parameters are set to values that would disable its effect. The backend sampler initialization then uses it to skip disabled samplers when determining the split point between backend and CPU samplers. The motivation is that the default sampler chain for `--samplers` can be used as-is: a sampler that is not enabled no longer causes the backend samplers to be skipped. For example, before this change, if the penalties sampler was included in the samplers list but had default values that disable it, the backend samplers would be skipped entirely. This commit also contains some refactoring to remove code duplication.
This commit is contained in:
parent 4fea191c66
commit 0f7805f32a
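To make the intended behavior concrete, the following is a minimal, hypothetical usage sketch. It is not part of the commit and is only illustrative (the new helpers are declared static next to the sampler code); it assumes the helpers and fields introduced in the diff below, with the penalty parameters set explicitly to the values that is_sampler_enabled treats as disabled.

    // Hypothetical sketch: a chain whose first sampler (penalties) is disabled by its values.
    common_params_sampling params;
    params.backend_sampling = true;
    params.samplers = { COMMON_SAMPLER_TYPE_PENALTIES,
                        COMMON_SAMPLER_TYPE_TOP_K,
                        COMMON_SAMPLER_TYPE_TEMPERATURE };

    params.penalty_last_n  = 64;    // together these values disable the penalties sampler
    params.penalty_repeat  = 1.0f;
    params.penalty_freq    = 0.0f;
    params.penalty_present = 0.0f;

    params.top_k = 40;              // enabled and backend-supported
    params.temp  = 0.8f;            // enabled and backend-supported

    struct active_samplers split = get_active_samplers(params);
    // split.backend_samplers == { TOP_K, TEMPERATURE } -> a backend chain is built
    // split.cpu_samplers is empty                      -> the dist sampler goes on the backend chain
    // Before this change, the disabled penalties sampler at the front of the list kept
    // backend_sampler_count at 0, so (with no logit bias set) the backend samplers were
    // skipped entirely.

With this split, common_sampler_init and common_sampler_backend_init only iterate the already-filtered lists instead of re-deriving the split point, which is where the removed duplication came from.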
@@ -1019,7 +1019,6 @@ struct common_init_result common_init_from_params(common_params & params) {
     auto cparams = common_context_params_to_llama(params);
 
-    // backend sampling initialization
     if (params.sampling.backend_sampling) {
         llama_sampler * backend_chain = common_sampler_backend_init(model, params.sampling);
         if (backend_chain != nullptr) {
@@ -165,10 +165,6 @@ struct common_sampler {
     mutable int64_t t_total_us = 0;
 };
 
-static bool sampler_enabled(const struct common_params_sampling & params, enum common_sampler_type type) {
-    return std::find(params.samplers.begin(), params.samplers.end(), type) != params.samplers.end();
-}
-
 static bool sampler_backend_supported(enum common_sampler_type type) {
     switch (type) {
         case COMMON_SAMPLER_TYPE_TOP_K:
@@ -180,10 +176,100 @@ static bool sampler_backend_supported(enum common_sampler_type type) {
     }
 }
 
+static bool is_sampler_enabled(enum common_sampler_type type, const struct common_params_sampling & params) {
+    switch (type) {
+        case COMMON_SAMPLER_TYPE_PENALTIES:
+            if (params.penalty_last_n == 64 &&
+                fabs(params.penalty_repeat) <= 1.0f &&
+                fabs(params.penalty_freq) <= 0.0f &&
+                fabs(params.penalty_present) <= 0.0f) {
+                return false;
+            }
+            break;
+        case COMMON_SAMPLER_TYPE_DRY:
+            if (params.dry_multiplier == 0.0f && params.dry_base == 1.75f) {
+                return false;
+            }
+            break;
+        case COMMON_SAMPLER_TYPE_TYPICAL_P:
+            if (params.typ_p == 1.0) {
+                return false;
+            }
+            break;
+        case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
+            if (params.top_n_sigma == -1.0) {
+                return false;
+            }
+            break;
+        case COMMON_SAMPLER_TYPE_TOP_K:
+            if (params.top_k <= 0) {
+                return false;
+            }
+            break;
+        case COMMON_SAMPLER_TYPE_TEMPERATURE:
+            if (params.temp <= 0.0f) {
+                return false;
+            }
+            break;
+        case COMMON_SAMPLER_TYPE_MIN_P:
+            if (params.min_p <= 0.0f) {
+                return false;
+            }
+            break;
+        case COMMON_SAMPLER_TYPE_TOP_P:
+            if (params.top_p >= 1.0f) {
+                return false;
+            }
+            break;
+        case COMMON_SAMPLER_TYPE_XTC:
+            if (params.xtc_probability == 0.0f && params.xtc_threshold == 0.10f) {
+                return false;
+            }
+            break;
+        default:
+            break;
+    }
+    return true;
+}
+
 static bool has_logit_bias(const struct common_params_sampling & params) {
     return !params.logit_bias.empty();
 }
 
+struct active_samplers {
+    std::vector<common_sampler_type> backend_samplers;
+    std::vector<common_sampler_type> cpu_samplers;
+};
+
+static struct active_samplers get_active_samplers(const struct common_params_sampling & params) {
+    struct active_samplers result;
+
+    if (params.mirostat != 0) {
+        // Mirostat is CPU-only and overrides other samplers
+        for (const auto & sampler_type : params.samplers) {
+            if (is_sampler_enabled(sampler_type, params)) {
+                result.cpu_samplers.push_back(sampler_type);
+            }
+        }
+        return result;
+    }
+
+    bool backend_supported = params.backend_sampling;
+
+    for (const auto & sampler_type : params.samplers) {
+        if (!is_sampler_enabled(sampler_type, params)) {
+            continue;
+        }
+
+        if (backend_supported && sampler_backend_supported(sampler_type)) {
+            result.backend_samplers.push_back(sampler_type);
+        } else {
+            result.cpu_samplers.push_back(sampler_type);
+        }
+    }
+    return result;
+}
+
 std::string common_params_sampling::print() const {
     char result[1024];
@@ -200,6 +286,14 @@ std::string common_params_sampling::print() const {
     return std::string(result);
 }
 
+struct backend_chain_data {
+    struct llama_sampler * chain;
+    size_t count;
+};
+
+static struct backend_chain_data backend_samplers_init(const struct llama_model * model, const struct common_params_sampling & params,
+                                                       struct active_samplers get_active_samplers);
+
 struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
@@ -277,69 +371,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
         /* .cur_p = */ {},
     };
 
-    size_t backend_sampler_count = 0;
-    if (params.backend_sampling && params.mirostat == 0) {
-        if (has_logit_bias(params)) {
-            backend_sampler_count++;
-        }
-
-        // Find the longest contiguous chain of backend-supported samplers from the start
-        for (const auto & sampler_type : params.samplers) {
-            if (sampler_backend_supported(sampler_type)) {
-                backend_sampler_count++;
-            } else {
-                break;
-            }
-        }
-    }
-
-    // If the samplers combination is supported then we can build the backend chain.
-    if (backend_sampler_count > 0 || (params.backend_sampling && has_logit_bias(params))) {
-        llama_sampler_chain_params backend_params = llama_sampler_chain_default_params();
-        backend_params.no_perf = params.no_perf;
-        result->backend_chain = llama_sampler_chain_init(backend_params);
-
-        if (has_logit_bias(params)) {
-            llama_sampler_chain_add(result->backend_chain,
-                llama_sampler_backend_init_logit_bias(
-                    llama_vocab_n_tokens(vocab),
-                    params.logit_bias.size(),
-                    params.logit_bias.data()));
-        }
-
-        size_t backend_idx = 0;
-        for (const auto & sampler_type : params.samplers) {
-            if (backend_idx >= backend_sampler_count - has_logit_bias(params)) {
-                break;
-            }
-
-            switch (sampler_type) {
-                case COMMON_SAMPLER_TYPE_TOP_K:
-                    if (params.top_k > 0) {
-                        llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_top_k(params.top_k));
-                    }
-                    backend_idx++;
-                    break;
-                case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                    if (params.temp > 0.0f) {
-                        llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_temp(params.temp));
-                    }
-                    backend_idx++;
-                    break;
-                case COMMON_SAMPLER_TYPE_MIN_P:
-                    if (params.min_p > 0.0f) {
-                        llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_min_p(params.min_p));
-                    }
-                    backend_idx++;
-                    break;
-                default:
-                    GGML_ASSERT(false && "unsupported backend sampler");
-            }
-        }
-    }
-
-    size_t cpu_start_idx = backend_sampler_count - has_logit_bias(params);
-    bool cpu_has_samplers = cpu_start_idx < params.samplers.size();
+    struct active_samplers active_samplers = get_active_samplers(params);
+    backend_chain_data backend_data = backend_samplers_init(model, params, active_samplers);
+    result->backend_chain = backend_data.chain;
 
     // Build CPU chain
     if (!params.backend_sampling || !has_logit_bias(params)) {
@@ -352,8 +386,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
     if (params.mirostat == 0) {
         // Add remaining CPU samplers
-        for (size_t i = cpu_start_idx; i < params.samplers.size(); i++) {
-            const auto & cnstr = params.samplers[i];
+        for (const auto & cnstr : active_samplers.cpu_samplers) {
             switch (cnstr) {
                 case COMMON_SAMPLER_TYPE_DRY:
                     {
@@ -398,10 +431,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             }
         }
 
-        // If all samplers are on backend, add dist to backend; otherwise add to CPU
-        if (result->backend_chain && !cpu_has_samplers) {
-            llama_sampler_chain_add(result->backend_chain, llama_sampler_backend_init_dist(params.seed));
-        } else {
+        if (!active_samplers.cpu_samplers.empty()) {
             llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
         }
     } else if (params.mirostat == 1) {
@@ -417,35 +447,17 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
     return result;
 }
 
-struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) {
-    if (!params.backend_sampling || params.mirostat != 0) {
-        return nullptr;
+static struct backend_chain_data backend_samplers_init(const struct llama_model * model, const struct common_params_sampling & params,
+                                                       struct active_samplers active_samplers) {
+    if (active_samplers.backend_samplers.empty()) {
+        return { nullptr, 0 };
     }
 
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
-    // Determine the split point for backend sampling using the same logic as common_sampler_init
-    size_t backend_sampler_count = 0;
-    if (has_logit_bias(params)) {
-        backend_sampler_count++;
-    }
-
-    // Find the longest contiguous chain of backend-supported samplers from the start
-    for (const auto & sampler_type : params.samplers) {
-        if (sampler_backend_supported(sampler_type)) {
-            backend_sampler_count++;
-        } else {
-            break;
-        }
-    }
-
-    if (backend_sampler_count == 0 && !has_logit_bias(params)) {
-        return nullptr;
-    }
-
     llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
     chain_params.no_perf = params.no_perf;
 
     struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
 
     // Add logit_bias to backend chain if present
@@ -456,46 +468,32 @@ struct llama_sampler * common_sampler_backend_init(const struct llama_model * mo
                 params.logit_bias.data()));
     }
 
-    size_t backend_idx = 0;
-    for (const auto & sampler_type : params.samplers) {
-        if (backend_idx >= backend_sampler_count - has_logit_bias(params)) {
-            break;
-        }
-
+    for (const auto & sampler_type : active_samplers.backend_samplers) {
         switch (sampler_type) {
             case COMMON_SAMPLER_TYPE_TOP_K:
-                if (params.top_k > 0) {
-                    llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k));
-                }
-                backend_idx++;
+                llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k));
                 break;
             case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                if (params.temp > 0.0f) {
-                    llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp));
-                }
-                backend_idx++;
+                llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp));
                 break;
             case COMMON_SAMPLER_TYPE_MIN_P:
-                if (params.min_p > 0.0f) {
-                    llama_sampler_chain_add(chain, llama_sampler_backend_init_min_p(params.min_p));
-                }
-                backend_idx++;
+                llama_sampler_chain_add(chain, llama_sampler_backend_init_min_p(params.min_p));
                 break;
             default:
                 GGML_ASSERT(false && "unsupported backend sampler");
         }
     }
 
-    // Determine if we should add dist sampler to backend chain
-    // Only add it if all samplers from params.samplers are on the backend
-    size_t cpu_start_idx = backend_sampler_count - has_logit_bias(params);
-    bool cpu_has_samplers = cpu_start_idx < params.samplers.size();
-
-    if (!cpu_has_samplers) {
+    if (active_samplers.cpu_samplers.empty()) {
         llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(params.seed));
     }
 
-    return chain;
+    return { chain, active_samplers.backend_samplers.size() + has_logit_bias(params) };
+}
+
+struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) {
+    struct active_samplers active_samplers = get_active_samplers(params);
+    return backend_samplers_init(model, params, active_samplers).chain;
 }
 
 void common_sampler_free(struct common_sampler * gsmpl) {