diff --git a/common/common.h b/common/common.h
index 9b53d2b56f..127a8cff1d 100644
--- a/common/common.h
+++ b/common/common.h
@@ -212,18 +212,12 @@ struct common_params_sampling {
     std::vector<llama_logit_bias> logit_bias;     // logit biases to apply
     std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
 
-    bool backend_sampling = false; // enable backend sampling
+    bool backend_sampling = false;
 
     bool has_logit_bias() const {
         return !logit_bias.empty();
     }
 
-    bool is_disabled(enum common_sampler_type type) const;
-
-    // remove disabled samplers
-    // TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
-    void filter_disabled();
-
     // print the parameters into a string
     std::string print() const;
 };
@@ -661,7 +655,7 @@ std::vector<std::string> fs_list_files(const std::string & path);
 
 struct common_sampler;
 
-// note: defines object's lifetime
+// note: defines the lifetimes of the model, context, samplers, etc.
 struct common_init_result {
     common_init_result(common_params & params);
     ~common_init_result();
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 2a6f57cd74..b7dfed547b 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -163,84 +163,6 @@ struct common_sampler {
     mutable int64_t t_total_us = 0;
 };
 
-// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
-static bool common_sampler_type_has_backend_support(enum common_sampler_type type) {
-    switch (type) {
-        case COMMON_SAMPLER_TYPE_TOP_K:
-        case COMMON_SAMPLER_TYPE_TEMPERATURE:
-        case COMMON_SAMPLER_TYPE_MIN_P:
-        case COMMON_SAMPLER_TYPE_TOP_P:
-            return true;
-        default:
-            return false;
-    }
-}
-
-bool common_params_sampling::is_disabled(enum common_sampler_type type) const {
-    switch (type) {
-        case COMMON_SAMPLER_TYPE_PENALTIES:
-            if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_DRY:
-            if (dry_multiplier == 0.0f || dry_base < 1.0f || dry_penalty_last_n == 0) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_TYPICAL_P:
-            if (typ_p >= 1.0) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
-            if (top_n_sigma <= 0.0) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_TOP_K:
-            if (top_k <= 0) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_TEMPERATURE:
-            if (dynatemp_range <= 0.0f) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_MIN_P:
-            if (min_p <= 0.0f) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_TOP_P:
-            if (top_p >= 1.0f) {
-                return true;
-            }
-            break;
-        case COMMON_SAMPLER_TYPE_XTC:
-            if (xtc_probability <= 0.0f || xtc_threshold == 0.50f) {
-                return true;
-            }
-            break;
-        default:
-            break;
-    }
-
-    return false;
-}
-
-void common_params_sampling::filter_disabled() {
-    for (auto it = samplers.begin(); it != samplers.end();) {
-        if (is_disabled(*it)) {
-            LOG_WRN("%s: removing disabled sampler %s\n", __func__, common_sampler_type_to_str(*it).c_str());
-            it = samplers.erase(it);
-        } else {
-            ++it;
-        }
-    }
-}
-
 std::string common_params_sampling::print() const {
     char result[1024];
 
@@ -257,7 +179,7 @@ std::string common_params_sampling::print() const {
     return std::string(result);
 }
 
-struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
@@ -324,11 +246,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
         }
     }
 
-    // TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
-    if (params.backend_sampling) {
-        params.filter_disabled();
-    }
-
     auto * result = new common_sampler {
         /* .params = */ params,
         /* .grmr   = */ grmr,
@@ -339,54 +256,13 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
         /* .cur_p  = */ {},
     };
 
-    size_t idx_smpl = 0;
-
-    bool is_backend = true;
-
-    is_backend = is_backend && params.backend_sampling;
-    is_backend = is_backend && (params.samplers.size() == 0 || common_sampler_type_has_backend_support(params.samplers[idx_smpl]));
-
+    std::vector<llama_sampler *> samplers;
+
     if (params.has_logit_bias()) {
-        llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain,
-                llama_sampler_init_logit_bias(
-                    llama_vocab_n_tokens(vocab),
-                    params.logit_bias.size(),
-                    params.logit_bias.data()));
+        samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
     }
 
     if (params.mirostat == 0) {
-        // backend samplers are added first
-        while (is_backend && idx_smpl < params.samplers.size()) {
-            const auto & cnstr = params.samplers[idx_smpl++];
-
-            if (!common_sampler_type_has_backend_support(cnstr)) {
-                is_backend = false;
-                --idx_smpl;
-                break;
-            }
-
-            switch (cnstr) {
-                case COMMON_SAMPLER_TYPE_TOP_K:
-                    llama_sampler_chain_add(result->chain_backend, llama_sampler_init_top_k(params.top_k));
-                    break;
-                case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                    llama_sampler_chain_add(result->chain_backend, llama_sampler_init_temp(params.temp));
-                    break;
-                case COMMON_SAMPLER_TYPE_MIN_P:
-                    llama_sampler_chain_add(result->chain_backend, llama_sampler_init_min_p(params.min_p, params.min_keep));
-                    break;
-                case COMMON_SAMPLER_TYPE_TOP_P:
-                    llama_sampler_chain_add(result->chain_backend, llama_sampler_init_top_p(params.top_p, params.min_keep));
-                    break;
-                default:
-                    GGML_ASSERT(false && "unsupported backend sampler");
-            }
-        }
-
-        // Add remaining CPU samplers
-        while (idx_smpl < params.samplers.size()) {
-            const auto & cnstr = params.samplers[idx_smpl++];
-
+        for (const auto & cnstr : params.samplers) {
             switch (cnstr) {
                 case COMMON_SAMPLER_TYPE_DRY:
                     {
@@ -396,52 +272,63 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
                            c_breakers.push_back(str.c_str());
                        }
 
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_dry    (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                        samplers.push_back(llama_sampler_init_dry    (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                     } break;
                 case COMMON_SAMPLER_TYPE_TOP_K:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
+                    samplers.push_back(llama_sampler_init_top_k    (params.top_k));
                     break;
                 case COMMON_SAMPLER_TYPE_TOP_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
+                    samplers.push_back(llama_sampler_init_top_p    (params.top_p, params.min_keep));
                     break;
                 case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma));
+                    samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
                     break;
                 case COMMON_SAMPLER_TYPE_MIN_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
+                    samplers.push_back(llama_sampler_init_min_p    (params.min_p, params.min_keep));
                     break;
                 case COMMON_SAMPLER_TYPE_XTC:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                    samplers.push_back(llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
                     break;
                 case COMMON_SAMPLER_TYPE_TYPICAL_P:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
+                    samplers.push_back(llama_sampler_init_typical  (params.typ_p, params.min_keep));
                     break;
                 case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                    samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                     break;
                 case COMMON_SAMPLER_TYPE_INFILL:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (vocab));
+                    samplers.push_back(llama_sampler_init_infill   (vocab));
                     break;
                 case COMMON_SAMPLER_TYPE_PENALTIES:
-                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                    samplers.push_back(llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
                     break;
                 default:
                     GGML_ASSERT(false && "unknown sampler type");
             }
         }
 
-        llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, llama_sampler_init_dist(params.seed));
+        samplers.push_back(llama_sampler_init_dist(params.seed));
     } else if (params.mirostat == 1) {
-        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+        samplers.push_back(llama_sampler_init_temp(params.temp));
+        samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
    } else if (params.mirostat == 2) {
-        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
+        samplers.push_back(llama_sampler_init_temp(params.temp));
+        samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
     } else {
         GGML_ASSERT(false && "unknown mirostat version");
     }
 
+    bool is_backend = params.backend_sampling;
+
+    // split into two chains: backend -> CPU
+    for (auto * smpl : samplers) {
+        if (!smpl->iface->backend_apply) {
+            is_backend = false;
+        }
+
+        llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, smpl);
+    }
+
     return result;
 }
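A note on the split loop above: `is_backend` latches to false at the first sampler whose iface lacks `backend_apply`, so the backend chain receives only a prefix of the configured sampler order, and everything from the first CPU-only sampler onward (including the final dist sampler) lands on the CPU chain. A standalone sketch of that behavior, using simplified placeholder types rather than the llama.cpp API:

// sketch of the latching backend/CPU split (placeholder types, not llama.cpp API)
#include <string>
#include <vector>

struct sampler_desc {
    std::string name;
    bool        has_backend_apply; // stands in for smpl->iface->backend_apply != nullptr
};

static void split_chains(
        const std::vector<sampler_desc> & samplers, bool backend_sampling,
        std::vector<sampler_desc> & chain_backend, std::vector<sampler_desc> & chain_cpu) {
    bool is_backend = backend_sampling;

    for (const auto & s : samplers) {
        if (!s.has_backend_apply) {
            is_backend = false; // latches: this and all later samplers go to the CPU chain
        }

        (is_backend ? chain_backend : chain_cpu).push_back(s);
    }
}

// e.g. { top_k, temp, dry, dist } with backend_sampling == true and dry lacking
// backend_apply yields chain_backend = { top_k, temp } and chain_cpu = { dry, dist }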
diff --git a/common/sampling.h b/common/sampling.h
index 06f27923a0..04b56dbbed 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -36,8 +36,7 @@ struct common_sampler;
 
 // llama_sampler API overloads
 
-// TODO: params should become const again [LLAMA_SAMPLER_BACKEND]
-struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params);
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
 
 void common_sampler_free(struct common_sampler * gsmpl);
diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp
index 97cad5d260..0eb76316cb 100644
--- a/examples/batched/batched.cpp
+++ b/examples/batched/batched.cpp
@@ -73,16 +73,10 @@ int main(int argc, char ** argv) {
     for (int32_t i = 0; i < n_parallel; ++i) {
         llama_sampler * smpl = llama_sampler_chain_init(sparams);
 
-        if (params.sampling.backend_sampling) {
-            llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
-            llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
-            llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
-        } else {
-            llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
-            llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
-            llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
-            llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
-        }
+        llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
+        llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
+        llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
+        llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
 
         sampler_configs.push_back({ i, smpl });
     }
diff --git a/include/llama.h b/include/llama.h
index 01eca7609a..f6926b6063 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1212,7 +1212,7 @@ extern "C" {
     };
 
     struct llama_sampler {
-        const struct llama_sampler_i * iface;
+        struct llama_sampler_i * iface;
         llama_sampler_context_t ctx;
     };
 
@@ -1220,7 +1220,7 @@ extern "C" {
     LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
 
     // mirror of llama_sampler_i:
-    LLAMA_API struct llama_sampler * llama_sampler_init  (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
+    LLAMA_API struct llama_sampler * llama_sampler_init  (      struct llama_sampler_i * iface, llama_sampler_context_t ctx);
     LLAMA_API const char *           llama_sampler_name  (const struct llama_sampler * smpl);
     LLAMA_API void                   llama_sampler_accept(      struct llama_sampler * smpl, llama_token token);
     LLAMA_API void                   llama_sampler_apply (      struct llama_sampler * smpl, llama_token_data_array * cur_p);
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 90e0a2658a..a621c4ebf5 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -2102,7 +2102,7 @@ void llm_graph_context::build_sampling() const {
             ggml_build_forward_expand(gf, data.sampled);
         }
 
-        if (data.probs != nullptr) { 
+        if (data.probs != nullptr) {
             res->t_sampled_probs[seq_id] = data.probs;
 
             ggml_build_forward_expand(gf, data.probs);
         }
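For context, the batched.cpp hunk above builds one chain per parallel sequence and pairs it with `llama_set_sampler`, declared in the include/llama.h hunk. A minimal usage sketch with hypothetical parameter values, assuming an already-initialized `llama_context * ctx`; whether the context takes ownership of the chain is not established by this diff:

// usage sketch (hypothetical values; ctx assumed initialized)
llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());

llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.95f, 1));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8f));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (LLAMA_DEFAULT_SEED));

// attach the chain to sequence 0 so its backend-capable prefix can run in the decode graph
llama_set_sampler(ctx, 0, smpl);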
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index fd4e770e3c..ca6c3670b1 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -349,7 +349,7 @@ static uint32_t get_rng_seed(uint32_t seed) {
 // llama_sampler API
 
 struct llama_sampler * llama_sampler_init(
-        const struct llama_sampler_i * iface,
+              struct llama_sampler_i * iface,
               llama_sampler_context_t   ctx) {
     return new llama_sampler {
         /* .iface = */ iface,
@@ -468,6 +468,42 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
     return token;
 }
 
+// backend sampling (empty iface)
+
+static void llama_sampler_empty_backend_init(
+        struct llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    GGML_UNUSED(smpl);
+    GGML_UNUSED(buft);
+}
+
+static void llama_sampler_empty_backend_accept(
+        struct llama_sampler * smpl,
+        ggml_context * ctx,
+        ggml_cgraph * gf,
+        struct ggml_tensor * selected_token) {
+    GGML_UNUSED(smpl);
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(gf);
+    GGML_UNUSED(selected_token);
+}
+
+static void llama_sampler_empty_backend_apply(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_data * data) {
+    GGML_UNUSED(smpl);
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(gf);
+    GGML_UNUSED(data);
+}
+
+static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) {
+    GGML_UNUSED(smpl);
+}
+
+
 // sampler chain
 
 static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
@@ -1171,7 +1207,7 @@ static void llama_sampler_top_p_backend_apply(
     ggml_set_output(data->candidates);
     ggml_build_forward_expand(gf, data->candidates);
- 
+
     ggml_set_output(data->logits);
     ggml_build_forward_expand(gf, data->logits);
 }
@@ -1446,13 +1482,24 @@ static struct llama_sampler_i llama_sampler_typical_i = {
 };
 
 struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
-    return llama_sampler_init(
+    auto * res = llama_sampler_init(
         /* .iface = */ &llama_sampler_typical_i,
         /* .ctx   = */ new llama_sampler_typical {
             /* .p        = */ p,
             /* .min_keep = */ min_keep,
         }
     );
+
+    const bool is_empty = (p >= 1.0f);
+
+    if (is_empty) {
+        res->iface->backend_init      = llama_sampler_empty_backend_init;
+        res->iface->backend_accept    = llama_sampler_empty_backend_accept;
+        res->iface->backend_apply     = llama_sampler_empty_backend_apply;
+        res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
+    }
+
+    return res;
 }
 
 // temp
@@ -1615,6 +1662,27 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
     delete (llama_sampler_temp_ext *) smpl->ctx;
 }
 
+static void llama_sampler_temp_ext_backend_apply(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_data * data) {
+    auto * ctx_data = (llama_sampler_temp_ext *) smpl->ctx;
+
+    if (ctx_data->temp <= 0.0f) {
+        return;
+    }
+
+    struct ggml_tensor * scaled = ggml_scale(ctx, data->logits, 1.0f / ctx_data->temp);
+    ggml_set_name(scaled, "temp_scaled");
+
+    // Make sure the scaled tensor is contiguous for subsequent operations
+    data->logits = ggml_cont(ctx, scaled);
+    ggml_set_name(data->logits, "temp_scaled_logits");
+
+    ggml_build_forward_expand(gf, data->logits);
+}
+
 static struct llama_sampler_i llama_sampler_temp_ext_i = {
     /* .name   = */ llama_sampler_temp_ext_name,
     /* .accept = */ nullptr,
@@ -1629,7 +1697,7 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
 };
 
 struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
-    return llama_sampler_init(
+    auto * res = llama_sampler_init(
         /* .iface = */ &llama_sampler_temp_ext_i,
         /* .ctx   = */ new llama_sampler_temp_ext {
             /* .temp  = */ temp,
@@ -1637,6 +1705,14 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
             /* .exponent = */ exponent,
         }
     );
+
+    const bool is_backend = delta <= 0.0f;
+
+    if (is_backend) {
+        res->iface->backend_apply = llama_sampler_temp_ext_backend_apply;
+    }
+
+    return res;
 }
 
 // xtc
@@ -1727,8 +1803,9 @@ static struct llama_sampler_i llama_sampler_xtc_i = {
 };
 
 struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
-    auto seed_cur = get_rng_seed(seed);
-    return llama_sampler_init(
+    const auto seed_cur = get_rng_seed(seed);
+
+    auto * res = llama_sampler_init(
         /* .iface = */ &llama_sampler_xtc_i,
         /* .ctx   = */ new llama_sampler_xtc {
             /* .probability = */ p,
@@ -1739,6 +1816,17 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
             /* .rng         = */ std::mt19937(seed_cur),
         }
     );
+
+    const bool is_empty = (p <= 0.0f || t > 0.5f);
+
+    if (is_empty) {
+        res->iface->backend_init      = llama_sampler_empty_backend_init;
+        res->iface->backend_accept    = llama_sampler_empty_backend_accept;
+        res->iface->backend_apply     = llama_sampler_empty_backend_apply;
+        res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
+    }
+
+    return res;
 }
 
 // mirostat
@@ -2280,7 +2368,7 @@ struct llama_sampler * llama_sampler_init_penalties(
         float penalty_present) {
     penalty_last_n = std::max(penalty_last_n, 0);
 
-    return llama_sampler_init(
+    auto * res = llama_sampler_init(
         /* .iface = */ &llama_sampler_penalties_i,
         /* .ctx   = */ new llama_sampler_penalties {
             /* .penalty_last_n  = */ penalty_last_n,
@@ -2291,6 +2379,17 @@ struct llama_sampler * llama_sampler_init_penalties(
             /* .token_count     = */ {},
         }
     );
+
+    const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
+
+    if (is_empty) {
+        res->iface->backend_init      = llama_sampler_empty_backend_init;
+        res->iface->backend_accept    = llama_sampler_empty_backend_accept;
+        res->iface->backend_apply     = llama_sampler_empty_backend_apply;
+        res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
+    }
+
+    return res;
 }
 
 // top-n-sigma
@@ -2317,9 +2416,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t
     for (size_t i = 0; i < cur_p->size; ++i) {
         // Only count non-negative infinity values
         if (cur_p->data[i].logit != -INFINITY) {
-            if (cur_p->data[i].logit > max) {
-                max = cur_p->data[i].logit;
-            }
+            max = std::max(max, cur_p->data[i].logit);
             logits_sum += cur_p->data[i].logit;
             valid_count++;
         }
@@ -2369,12 +2466,23 @@ static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
 };
 
 struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
-    return llama_sampler_init(
+    auto * res = llama_sampler_init(
        /* .iface = */ &llama_sampler_top_n_sigma_i,
         /* .ctx   = */ new llama_sampler_top_n_sigma {
             /* .n = */ n,
         }
     );
+
+    const bool is_empty = (n <= 0.0f);
+
+    if (is_empty) {
+        res->iface->backend_init      = llama_sampler_empty_backend_init;
+        res->iface->backend_accept    = llama_sampler_empty_backend_accept;
+        res->iface->backend_apply     = llama_sampler_empty_backend_apply;
+        res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
+    }
+
+    return res;
 }
 
 // DRY
@@ -2733,7 +2841,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
         }
     }
 
-    return llama_sampler_init(
+    auto * res = llama_sampler_init(
         /* .iface = */ &llama_sampler_dry_i,
         /* .ctx   = */ new llama_sampler_dry {
             /* .total_context_size     = */ n_ctx_train,
@@ -2747,6 +2855,15 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
             /* .last_tokens            = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
         }
     );
+
+    if (!dry_enabled) {
+        res->iface->backend_init      = llama_sampler_empty_backend_init;
+        res->iface->backend_accept    = llama_sampler_empty_backend_accept;
+        res->iface->backend_apply     = llama_sampler_empty_backend_apply;
+        res->iface->backend_set_input = llama_sampler_empty_backend_set_input;
+    }
+
+    return res;
 }
 
 // wrapper for test-sampling.cpp
@@ -2854,6 +2971,8 @@ static void llama_sampler_logit_bias_backend_apply(
     // Add the sparse logit_bias to the logits
     struct ggml_tensor * logit_biased = ggml_add_inplace(ctx, data->logits, sctx->inp_logit_bias);
 
+    data->logits = logit_biased;
+
     ggml_build_forward_expand(gf, logit_biased);
 }
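The same four empty-backend assignments recur in the typical, XTC, penalties, top-n-sigma and DRY constructors above. A hypothetical helper, not part of this patch, could consolidate the pattern; it relies on the iface pointer losing its const qualifier in include/llama.h, and since each `llama_sampler_*_i` is a shared static, the writes take effect per iface rather than per instance:

// hypothetical helper (not in this patch): install the empty backend callbacks,
// marking the sampler as a no-op for backend execution
static void llama_sampler_set_empty_backend(struct llama_sampler * smpl) {
    smpl->iface->backend_init      = llama_sampler_empty_backend_init;
    smpl->iface->backend_accept    = llama_sampler_empty_backend_accept;
    smpl->iface->backend_apply     = llama_sampler_empty_backend_apply;
    smpl->iface->backend_set_input = llama_sampler_empty_backend_set_input;
}

Each init function would then reduce to a single `if (is_empty) { llama_sampler_set_empty_backend(res); }`.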