diff --git a/common/sampling.cpp b/common/sampling.cpp index 7a6b7be1e0..07d7153384 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -243,6 +243,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co params.logit_bias.data())); if (params.mirostat == 0) { + // if this flag is set, we will not need to add `dist` at the end of the sampler chain + bool has_distribution_sampler = false; + for (const auto & cnstr : params.samplers) { switch (cnstr) { case COMMON_SAMPLER_TYPE_DRY: @@ -253,7 +256,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co c_breakers.push_back(str.c_str()); } - llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); + llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); } break; case COMMON_SAMPLER_TYPE_TOP_K: @@ -283,11 +286,18 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co case COMMON_SAMPLER_TYPE_PENALTIES: llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); break; + case COMMON_SAMPLER_TYPE_POWER_LAW: + llama_sampler_chain_add(result->chain, llama_sampler_init_power_law (params.power_law_target, params.power_law_target_range, params.power_law_window_size, params.seed)); + has_distribution_sampler = true; + break; default: GGML_ASSERT(false && "unknown sampler type"); } } - llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); + // only add `dist` to the end of the chain if no other distribution samplers were added + if (!has_distribution_sampler) { + llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); + } } else if (params.mirostat == 1) { llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); @@ -586,6 +596,7 @@ std::vector common_sampler_types_from_names(const std::vect { "xtc", COMMON_SAMPLER_TYPE_XTC }, { "infill", COMMON_SAMPLER_TYPE_INFILL }, { "penalties", COMMON_SAMPLER_TYPE_PENALTIES }, + { "power_law", COMMON_SAMPLER_TYPE_POWER_LAW }, }; // since samplers names are written multiple ways @@ -601,6 +612,7 @@ std::vector common_sampler_types_from_names(const std::vect { "typ", COMMON_SAMPLER_TYPE_TYPICAL_P }, { "min-p", COMMON_SAMPLER_TYPE_MIN_P }, { "temp", COMMON_SAMPLER_TYPE_TEMPERATURE }, + { "power-law", COMMON_SAMPLER_TYPE_POWER_LAW }, }; std::vector samplers;