llama: introduce support for model-embedded sampling parameters (#17120)
This commit is contained in:
parent
3d07caa99b
commit
877566d512
|
|
@ -1232,6 +1232,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
[](common_params & params, const std::string & value) {
|
[](common_params & params, const std::string & value) {
|
||||||
const auto sampler_names = string_split<std::string>(value, ';');
|
const auto sampler_names = string_split<std::string>(value, ';');
|
||||||
params.sampling.samplers = common_sampler_types_from_names(sampler_names, true);
|
params.sampling.samplers = common_sampler_types_from_names(sampler_names, true);
|
||||||
|
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS;
|
||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
|
|
@ -1261,6 +1262,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
[](common_params & params, const std::string & value) {
|
[](common_params & params, const std::string & value) {
|
||||||
params.sampling.temp = std::stof(value);
|
params.sampling.temp = std::stof(value);
|
||||||
params.sampling.temp = std::max(params.sampling.temp, 0.0f);
|
params.sampling.temp = std::max(params.sampling.temp, 0.0f);
|
||||||
|
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP;
|
||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
|
|
@ -1268,6 +1270,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k),
|
string_format("top-k sampling (default: %d, 0 = disabled)", params.sampling.top_k),
|
||||||
[](common_params & params, int value) {
|
[](common_params & params, int value) {
|
||||||
params.sampling.top_k = value;
|
params.sampling.top_k = value;
|
||||||
|
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K;
|
||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
|
|
@ -1275,6 +1278,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
|
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),
|
||||||
[](common_params & params, const std::string & value) {
|
[](common_params & params, const std::string & value) {
|
||||||
params.sampling.top_p = std::stof(value);
|
params.sampling.top_p = std::stof(value);
|
||||||
|
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P;
|
||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
|
|
@ -1282,6 +1286,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p),
|
string_format("min-p sampling (default: %.1f, 0.0 = disabled)", (double)params.sampling.min_p),
|
||||||
[](common_params & params, const std::string & value) {
|
[](common_params & params, const std::string & value) {
|
||||||
params.sampling.min_p = std::stof(value);
|
params.sampling.min_p = std::stof(value);
|
||||||
|
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P;
|
||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
|
|
@ -1296,6 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
|
string_format("xtc probability (default: %.1f, 0.0 = disabled)", (double)params.sampling.xtc_probability),
|
||||||
[](common_params & params, const std::string & value) {
|
[](common_params & params, const std::string & value) {
|
||||||
params.sampling.xtc_probability = std::stof(value);
|
params.sampling.xtc_probability = std::stof(value);
|
||||||
|
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY;
|
||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
|
|
@ -1303,6 +1309,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
|
string_format("xtc threshold (default: %.1f, 1.0 = disabled)", (double)params.sampling.xtc_threshold),
|
||||||
[](common_params & params, const std::string & value) {
|
[](common_params & params, const std::string & value) {
|
||||||
params.sampling.xtc_threshold = std::stof(value);
|
params.sampling.xtc_threshold = std::stof(value);
|
||||||
|
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD;
|
||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
|
|
@ -1321,6 +1328,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
}
|
}
|
||||||
params.sampling.penalty_last_n = value;
|
params.sampling.penalty_last_n = value;
|
||||||
params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
|
params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
|
||||||
|
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N;
|
||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
|
|
@ -1328,6 +1336,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
|
string_format("penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)params.sampling.penalty_repeat),
|
||||||
[](common_params & params, const std::string & value) {
|
[](common_params & params, const std::string & value) {
|
||||||
params.sampling.penalty_repeat = std::stof(value);
|
params.sampling.penalty_repeat = std::stof(value);
|
||||||
|
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT;
|
||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
|
|
@ -1425,6 +1434,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
"(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat),
|
"(default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)", params.sampling.mirostat),
|
||||||
[](common_params & params, int value) {
|
[](common_params & params, int value) {
|
||||||
params.sampling.mirostat = value;
|
params.sampling.mirostat = value;
|
||||||
|
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT;
|
||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
|
|
@ -1432,6 +1442,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta),
|
string_format("Mirostat learning rate, parameter eta (default: %.1f)", (double)params.sampling.mirostat_eta),
|
||||||
[](common_params & params, const std::string & value) {
|
[](common_params & params, const std::string & value) {
|
||||||
params.sampling.mirostat_eta = std::stof(value);
|
params.sampling.mirostat_eta = std::stof(value);
|
||||||
|
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA;
|
||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
|
|
@ -1439,6 +1450,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
||||||
string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau),
|
string_format("Mirostat target entropy, parameter tau (default: %.1f)", (double)params.sampling.mirostat_tau),
|
||||||
[](common_params & params, const std::string & value) {
|
[](common_params & params, const std::string & value) {
|
||||||
params.sampling.mirostat_tau = std::stof(value);
|
params.sampling.mirostat_tau = std::stof(value);
|
||||||
|
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU;
|
||||||
}
|
}
|
||||||
).set_sparam());
|
).set_sparam());
|
||||||
add_opt(common_arg(
|
add_opt(common_arg(
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "log.h"
|
#include "log.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
#include "sampling.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <cinttypes>
|
#include <cinttypes>
|
||||||
|
|
@ -949,6 +950,58 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
|
||||||
// Model utils
|
// Model utils
|
||||||
//
|
//
|
||||||
|
|
||||||
|
static inline void common_init_sampler_from_model(
|
||||||
|
const llama_model * model,
|
||||||
|
common_params_sampling & sparams) {
|
||||||
|
|
||||||
|
const uint64_t config = sparams.user_sampling_config;
|
||||||
|
|
||||||
|
auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
|
||||||
|
if (config & user_config) return;
|
||||||
|
|
||||||
|
char buf[64] = {0};
|
||||||
|
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
|
||||||
|
char * end = nullptr;
|
||||||
|
int32_t v = strtol(buf, &end, 10);
|
||||||
|
if (end && end != buf) dst = v;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
|
||||||
|
if (config & user_config) return;
|
||||||
|
|
||||||
|
char buf[128] = {0};
|
||||||
|
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
|
||||||
|
char * end = nullptr;
|
||||||
|
float v = strtof(buf, &end);
|
||||||
|
if (end && end != buf) dst = v;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Sampling sequence
|
||||||
|
if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) {
|
||||||
|
char buf[512] = {0};
|
||||||
|
if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) {
|
||||||
|
const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';');
|
||||||
|
if (!sampler_names.empty()) {
|
||||||
|
sparams.samplers = common_sampler_types_from_names(sampler_names, true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K);
|
||||||
|
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P);
|
||||||
|
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P);
|
||||||
|
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY);
|
||||||
|
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD);
|
||||||
|
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP);
|
||||||
|
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N);
|
||||||
|
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT);
|
||||||
|
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT);
|
||||||
|
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU);
|
||||||
|
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
|
||||||
|
}
|
||||||
|
|
||||||
struct common_init_result common_init_from_params(common_params & params) {
|
struct common_init_result common_init_from_params(common_params & params) {
|
||||||
common_init_result iparams;
|
common_init_result iparams;
|
||||||
auto mparams = common_model_params_to_llama(params);
|
auto mparams = common_model_params_to_llama(params);
|
||||||
|
|
@ -960,6 +1013,8 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||||
return iparams;
|
return iparams;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
common_init_sampler_from_model(model, params.sampling);
|
||||||
|
|
||||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||||
|
|
||||||
auto cparams = common_context_params_to_llama(params);
|
auto cparams = common_context_params_to_llama(params);
|
||||||
|
|
|
||||||
|
|
@ -140,6 +140,22 @@ struct common_grammar_trigger {
|
||||||
llama_token token = LLAMA_TOKEN_NULL;
|
llama_token token = LLAMA_TOKEN_NULL;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Bit flags recording which sampling parameters were set explicitly by the user.
// The CLI option handlers OR the matching flag into
// common_params_sampling::user_sampling_config; a set bit prevents the
// model-embedded default for that parameter from overriding the user's value.
//
// The shifts use a 64-bit operand on purpose: a plain `1 << n` is evaluated as `int`,
// so flags at bit 31 and above could not be expressed despite the uint64_t underlying type.
enum common_params_sampling_config : uint64_t {
    COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS        = 1ULL << 0,
    COMMON_PARAMS_SAMPLING_CONFIG_TOP_K           = 1ULL << 1,
    COMMON_PARAMS_SAMPLING_CONFIG_TOP_P           = 1ULL << 2,
    COMMON_PARAMS_SAMPLING_CONFIG_MIN_P           = 1ULL << 3,
    COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1ULL << 4,
    COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD   = 1ULL << 5,
    COMMON_PARAMS_SAMPLING_CONFIG_TEMP            = 1ULL << 6,
    COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N  = 1ULL << 7,
    COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT  = 1ULL << 8,
    COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT        = 1ULL << 9,
    COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU    = 1ULL << 10,
    COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA    = 1ULL << 11,
};
|
||||||
|
|
||||||
|
|
||||||
// sampling parameters
|
// sampling parameters
|
||||||
struct common_params_sampling {
|
struct common_params_sampling {
|
||||||
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
|
||||||
|
|
@ -172,6 +188,8 @@ struct common_params_sampling {
|
||||||
bool no_perf = false; // disable performance metrics
|
bool no_perf = false; // disable performance metrics
|
||||||
bool timing_per_token = false;
|
bool timing_per_token = false;
|
||||||
|
|
||||||
|
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
|
||||||
|
|
||||||
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
|
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,20 @@ class Keys:
|
||||||
ALIGNMENT = "general.alignment"
|
ALIGNMENT = "general.alignment"
|
||||||
FILE_TYPE = "general.file_type"
|
FILE_TYPE = "general.file_type"
|
||||||
|
|
||||||
|
# Recommended Sampler Parameters
|
||||||
|
SAMPLING_SEQUENCE = "general.sampling.sequence"
|
||||||
|
SAMPLING_TOP_K = "general.sampling.top_k"
|
||||||
|
SAMPLING_TOP_P = "general.sampling.top_p"
|
||||||
|
SAMPLING_MIN_P = "general.sampling.min_p"
|
||||||
|
SAMPLING_XTC_PROBABILITY = "general.sampling.xtc_probability"
|
||||||
|
SAMPLING_XTC_THRESHOLD = "general.sampling.xtc_threshold"
|
||||||
|
SAMPLING_TEMP = "general.sampling.temp"
|
||||||
|
SAMPLING_PENALTY_LAST_N = "general.sampling.penalty_last_n"
|
||||||
|
SAMPLING_PENALTY_REPEAT = "general.sampling.penalty_repeat"
|
||||||
|
SAMPLING_MIROSTAT = "general.sampling.mirostat"
|
||||||
|
SAMPLING_MIROSTAT_TAU = "general.sampling.mirostat_tau"
|
||||||
|
SAMPLING_MIROSTAT_ETA = "general.sampling.mirostat_eta"
|
||||||
|
|
||||||
# Authorship Metadata
|
# Authorship Metadata
|
||||||
NAME = "general.name"
|
NAME = "general.name"
|
||||||
AUTHOR = "general.author"
|
AUTHOR = "general.author"
|
||||||
|
|
|
||||||
|
|
@ -496,6 +496,42 @@ class GGUFWriter:
|
||||||
def add_file_type(self, ftype: int) -> None:
|
def add_file_type(self, ftype: int) -> None:
|
||||||
self.add_uint32(Keys.General.FILE_TYPE, ftype)
|
self.add_uint32(Keys.General.FILE_TYPE, ftype)
|
||||||
|
|
||||||
|
def add_sampling_sequence(self, sequence: str) -> None:
    """Write the recommended sampler sequence as the string KV ``general.sampling.sequence``."""
    self.add_string(Keys.General.SAMPLING_SEQUENCE, sequence)
|
||||||
|
|
||||||
|
def add_sampling_top_k(self, top_k: int) -> None:
    """Write the recommended top-k value as the int32 KV ``general.sampling.top_k``."""
    self.add_int32(Keys.General.SAMPLING_TOP_K, top_k)
|
||||||
|
|
||||||
|
def add_sampling_top_p(self, top_p: float) -> None:
    """Write the recommended top-p value as the float32 KV ``general.sampling.top_p``."""
    self.add_float32(Keys.General.SAMPLING_TOP_P, top_p)
|
||||||
|
|
||||||
|
def add_sampling_min_p(self, min_p: float) -> None:
    """Write the recommended min-p value as the float32 KV ``general.sampling.min_p``."""
    self.add_float32(Keys.General.SAMPLING_MIN_P, min_p)
|
||||||
|
|
||||||
|
def add_sampling_xtc_probability(self, xtc_probability: float) -> None:
    """Write the recommended XTC probability as the float32 KV ``general.sampling.xtc_probability``."""
    self.add_float32(Keys.General.SAMPLING_XTC_PROBABILITY, xtc_probability)
|
||||||
|
|
||||||
|
def add_sampling_xtc_threshold(self, xtc_threshold: float) -> None:
    """Write the recommended XTC threshold as the float32 KV ``general.sampling.xtc_threshold``."""
    self.add_float32(Keys.General.SAMPLING_XTC_THRESHOLD, xtc_threshold)
|
||||||
|
|
||||||
|
def add_sampling_temp(self, temp: float) -> None:
    """Write the recommended temperature as the float32 KV ``general.sampling.temp``."""
    self.add_float32(Keys.General.SAMPLING_TEMP, temp)
|
||||||
|
|
||||||
|
def add_sampling_penalty_last_n(self, penalty_last_n: int) -> None:
    """Write the recommended penalty window as the int32 KV ``general.sampling.penalty_last_n``."""
    self.add_int32(Keys.General.SAMPLING_PENALTY_LAST_N, penalty_last_n)
|
||||||
|
|
||||||
|
def add_sampling_penalty_repeat(self, penalty_repeat: float) -> None:
    """Write the recommended repeat penalty as the float32 KV ``general.sampling.penalty_repeat``."""
    self.add_float32(Keys.General.SAMPLING_PENALTY_REPEAT, penalty_repeat)
|
||||||
|
|
||||||
|
def add_sampling_mirostat(self, mirostat: int) -> None:
    """Write the recommended Mirostat mode as the int32 KV ``general.sampling.mirostat``."""
    self.add_int32(Keys.General.SAMPLING_MIROSTAT, mirostat)
|
||||||
|
|
||||||
|
def add_sampling_mirostat_tau(self, mirostat_tau: float) -> None:
    """Write the recommended Mirostat tau as the float32 KV ``general.sampling.mirostat_tau``."""
    self.add_float32(Keys.General.SAMPLING_MIROSTAT_TAU, mirostat_tau)
|
||||||
|
|
||||||
|
def add_sampling_mirostat_eta(self, mirostat_eta: float) -> None:
    """Write the recommended Mirostat eta as the float32 KV ``general.sampling.mirostat_eta``."""
    self.add_float32(Keys.General.SAMPLING_MIROSTAT_ETA, mirostat_eta)
|
||||||
|
|
||||||
def add_name(self, name: str) -> None:
|
def add_name(self, name: str) -> None:
|
||||||
self.add_string(Keys.General.NAME, name)
|
self.add_string(Keys.General.NAME, name)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,20 @@ logger = logging.getLogger("metadata")
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Metadata:
|
class Metadata:
|
||||||
|
# Recommended Sampler Parameters to be written to GGUF KV Store
|
||||||
|
sampling_sequence: Optional[str] = None
|
||||||
|
sampling_top_k: Optional[int] = None
|
||||||
|
sampling_top_p: Optional[float] = None
|
||||||
|
sampling_min_p: Optional[float] = None
|
||||||
|
sampling_xtc_probability: Optional[float] = None
|
||||||
|
sampling_xtc_threshold: Optional[float] = None
|
||||||
|
sampling_temp: Optional[float] = None
|
||||||
|
sampling_penalty_last_n: Optional[int] = None
|
||||||
|
sampling_penalty_repeat: Optional[float] = None
|
||||||
|
sampling_mirostat: Optional[int] = None
|
||||||
|
sampling_mirostat_tau: Optional[float] = None
|
||||||
|
sampling_mirostat_eta: Optional[float] = None
|
||||||
|
|
||||||
# Authorship Metadata to be written to GGUF KV Store
|
# Authorship Metadata to be written to GGUF KV Store
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
author: Optional[str] = None
|
author: Optional[str] = None
|
||||||
|
|
@ -54,15 +68,43 @@ class Metadata:
|
||||||
|
|
||||||
model_card = Metadata.load_model_card(model_path)
|
model_card = Metadata.load_model_card(model_path)
|
||||||
hf_params = Metadata.load_hf_parameters(model_path)
|
hf_params = Metadata.load_hf_parameters(model_path)
|
||||||
|
gen_config = Metadata.load_generation_config(model_path)
|
||||||
# TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter
|
# TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter
|
||||||
|
|
||||||
# heuristics
|
# heuristics
|
||||||
metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
|
metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params)
|
||||||
|
|
||||||
|
if gen_config:
|
||||||
|
metadata.sampling_sequence = gen_config.get("sequence", metadata.sampling_sequence)
|
||||||
|
metadata.sampling_top_k = gen_config.get("top_k", metadata.sampling_top_k)
|
||||||
|
metadata.sampling_top_p = gen_config.get("top_p", metadata.sampling_top_p)
|
||||||
|
metadata.sampling_min_p = gen_config.get("min_p", metadata.sampling_min_p)
|
||||||
|
metadata.sampling_xtc_probability = gen_config.get("xtc_probability", metadata.sampling_xtc_probability)
|
||||||
|
metadata.sampling_xtc_threshold = gen_config.get("xtc_threshold", metadata.sampling_xtc_threshold)
|
||||||
|
metadata.sampling_temp = gen_config.get("temperature", metadata.sampling_temp)
|
||||||
|
metadata.sampling_penalty_last_n = gen_config.get("penalty_last_n", metadata.sampling_penalty_last_n)
|
||||||
|
metadata.sampling_penalty_repeat = gen_config.get("penalty_repeat", metadata.sampling_penalty_repeat)
|
||||||
|
metadata.sampling_mirostat = gen_config.get("mirostat", metadata.sampling_mirostat)
|
||||||
|
metadata.sampling_mirostat_tau = gen_config.get("mirostat_tau", metadata.sampling_mirostat_tau)
|
||||||
|
metadata.sampling_mirostat_eta = gen_config.get("mirostat_eta", metadata.sampling_mirostat_eta)
|
||||||
|
|
||||||
# Metadata Override File Provided
|
# Metadata Override File Provided
|
||||||
# This is based on LLM_KV_NAMES mapping in llama.cpp
|
# This is based on LLM_KV_NAMES mapping in llama.cpp
|
||||||
metadata_override = Metadata.load_metadata_override(metadata_override_path)
|
metadata_override = Metadata.load_metadata_override(metadata_override_path)
|
||||||
|
|
||||||
|
metadata.sampling_sequence = metadata_override.get(Keys.General.SAMPLING_SEQUENCE, metadata.sampling_sequence)
|
||||||
|
metadata.sampling_top_k = metadata_override.get(Keys.General.SAMPLING_TOP_K, metadata.sampling_top_k)
|
||||||
|
metadata.sampling_top_p = metadata_override.get(Keys.General.SAMPLING_TOP_P, metadata.sampling_top_p)
|
||||||
|
metadata.sampling_min_p = metadata_override.get(Keys.General.SAMPLING_MIN_P, metadata.sampling_min_p)
|
||||||
|
metadata.sampling_xtc_probability = metadata_override.get(Keys.General.SAMPLING_XTC_PROBABILITY, metadata.sampling_xtc_probability)
|
||||||
|
metadata.sampling_xtc_threshold = metadata_override.get(Keys.General.SAMPLING_XTC_THRESHOLD, metadata.sampling_xtc_threshold)
|
||||||
|
metadata.sampling_temp = metadata_override.get(Keys.General.SAMPLING_TEMP, metadata.sampling_temp)
|
||||||
|
metadata.sampling_penalty_last_n = metadata_override.get(Keys.General.SAMPLING_PENALTY_LAST_N, metadata.sampling_penalty_last_n)
|
||||||
|
metadata.sampling_penalty_repeat = metadata_override.get(Keys.General.SAMPLING_PENALTY_REPEAT, metadata.sampling_penalty_repeat)
|
||||||
|
metadata.sampling_mirostat = metadata_override.get(Keys.General.SAMPLING_MIROSTAT, metadata.sampling_mirostat)
|
||||||
|
metadata.sampling_mirostat_tau = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_TAU, metadata.sampling_mirostat_tau)
|
||||||
|
metadata.sampling_mirostat_eta = metadata_override.get(Keys.General.SAMPLING_MIROSTAT_ETA, metadata.sampling_mirostat_eta)
|
||||||
|
|
||||||
metadata.name = metadata_override.get(Keys.General.NAME, metadata.name)
|
metadata.name = metadata_override.get(Keys.General.NAME, metadata.name)
|
||||||
metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
|
metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author)
|
||||||
metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
|
metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version)
|
||||||
|
|
@ -172,6 +214,23 @@ class Metadata:
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
|
@staticmethod
def load_generation_config(model_path: Optional[Path] = None) -> dict[str, Any]:
    """Load ``generation_config.json`` from a model directory on a best-effort basis.

    Args:
        model_path: Directory of the model; may be ``None``.

    Returns:
        The parsed JSON object, or an empty dict when the path is missing or not a
        directory, the file does not exist, cannot be read or decoded, is not valid
        JSON, or its top-level value is not a JSON object.
    """
    if model_path is None or not model_path.is_dir():
        return {}

    generation_config_path = model_path / "generation_config.json"

    if not generation_config_path.is_file():
        return {}

    try:
        with open(generation_config_path, "r", encoding="utf-8") as f:
            gen_config = json.load(f)
    except (OSError, ValueError):
        # not all models have valid generation_config.json
        # (ValueError covers both json.JSONDecodeError and UnicodeDecodeError)
        return {}

    # callers query the result with dict.get(), so anything but a JSON object is unusable
    if not isinstance(gen_config, dict):
        return {}

    return gen_config
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def id_to_title(string):
|
def id_to_title(string):
|
||||||
# Convert capitalization into title form unless acronym or version number
|
# Convert capitalization into title form unless acronym or version number
|
||||||
|
|
@ -546,6 +605,32 @@ class Metadata:
|
||||||
|
|
||||||
def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
|
def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
|
||||||
assert self.name is not None
|
assert self.name is not None
|
||||||
|
|
||||||
|
if self.sampling_sequence is not None:
|
||||||
|
gguf_writer.add_sampling_sequence(self.sampling_sequence)
|
||||||
|
if self.sampling_top_k is not None:
|
||||||
|
gguf_writer.add_sampling_top_k(self.sampling_top_k)
|
||||||
|
if self.sampling_top_p is not None:
|
||||||
|
gguf_writer.add_sampling_top_p(self.sampling_top_p)
|
||||||
|
if self.sampling_min_p is not None:
|
||||||
|
gguf_writer.add_sampling_min_p(self.sampling_min_p)
|
||||||
|
if self.sampling_xtc_probability is not None:
|
||||||
|
gguf_writer.add_sampling_xtc_probability(self.sampling_xtc_probability)
|
||||||
|
if self.sampling_xtc_threshold is not None:
|
||||||
|
gguf_writer.add_sampling_xtc_threshold(self.sampling_xtc_threshold)
|
||||||
|
if self.sampling_temp is not None:
|
||||||
|
gguf_writer.add_sampling_temp(self.sampling_temp)
|
||||||
|
if self.sampling_penalty_last_n is not None:
|
||||||
|
gguf_writer.add_sampling_penalty_last_n(self.sampling_penalty_last_n)
|
||||||
|
if self.sampling_penalty_repeat is not None:
|
||||||
|
gguf_writer.add_sampling_penalty_repeat(self.sampling_penalty_repeat)
|
||||||
|
if self.sampling_mirostat is not None:
|
||||||
|
gguf_writer.add_sampling_mirostat(self.sampling_mirostat)
|
||||||
|
if self.sampling_mirostat_tau is not None:
|
||||||
|
gguf_writer.add_sampling_mirostat_tau(self.sampling_mirostat_tau)
|
||||||
|
if self.sampling_mirostat_eta is not None:
|
||||||
|
gguf_writer.add_sampling_mirostat_eta(self.sampling_mirostat_eta)
|
||||||
|
|
||||||
gguf_writer.add_name(self.name)
|
gguf_writer.add_name(self.name)
|
||||||
|
|
||||||
if self.author is not None:
|
if self.author is not None:
|
||||||
|
|
|
||||||
|
|
@ -246,6 +246,21 @@ extern "C" {
|
||||||
LLAMA_KV_OVERRIDE_TYPE_STR,
|
LLAMA_KV_OVERRIDE_TYPE_STR,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Identifiers for the model-embedded sampling metadata keys.
// Use llama_model_meta_key_str() to obtain the key name for an identifier.
enum llama_model_meta_key {
    LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE,        // sampler sequence
    LLAMA_MODEL_META_KEY_SAMPLING_TOP_K,           // top-k
    LLAMA_MODEL_META_KEY_SAMPLING_TOP_P,           // top-p
    LLAMA_MODEL_META_KEY_SAMPLING_MIN_P,           // min-p
    LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY, // XTC probability
    LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD,   // XTC threshold
    LLAMA_MODEL_META_KEY_SAMPLING_TEMP,            // temperature
    LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N,  // penalty window size
    LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT,  // repeat penalty
    LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT,        // Mirostat mode
    LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU,    // Mirostat tau
    LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA,    // Mirostat eta
};
|
||||||
|
|
||||||
struct llama_model_kv_override {
|
struct llama_model_kv_override {
|
||||||
enum llama_model_kv_override_type tag;
|
enum llama_model_kv_override_type tag;
|
||||||
|
|
||||||
|
|
@ -518,6 +533,9 @@ extern "C" {
|
||||||
// Get the number of metadata key/value pairs
|
// Get the number of metadata key/value pairs
|
||||||
LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model);
|
LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model);
|
||||||
|
|
||||||
|
// Get sampling metadata key name. Returns nullptr if the key is invalid
|
||||||
|
LLAMA_API const char * llama_model_meta_key_str(enum llama_model_meta_key key);
|
||||||
|
|
||||||
// Get metadata key name by index
|
// Get metadata key name by index
|
||||||
LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);
|
LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -119,6 +119,18 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||||
{ LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
|
{ LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
|
||||||
{ LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
|
{ LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
|
||||||
{ LLM_KV_GENERAL_FILE_TYPE, "general.file_type" },
|
{ LLM_KV_GENERAL_FILE_TYPE, "general.file_type" },
|
||||||
|
{ LLM_KV_GENERAL_SAMPLING_SEQUENCE, "general.sampling.sequence" },
|
||||||
|
{ LLM_KV_GENERAL_SAMPLING_TOP_K, "general.sampling.top_k" },
|
||||||
|
{ LLM_KV_GENERAL_SAMPLING_TOP_P, "general.sampling.top_p" },
|
||||||
|
{ LLM_KV_GENERAL_SAMPLING_MIN_P, "general.sampling.min_p" },
|
||||||
|
{ LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, "general.sampling.xtc_probability" },
|
||||||
|
{ LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD, "general.sampling.xtc_threshold" },
|
||||||
|
{ LLM_KV_GENERAL_SAMPLING_TEMP, "general.sampling.temp" },
|
||||||
|
{ LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N, "general.sampling.penalty_last_n" },
|
||||||
|
{ LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT, "general.sampling.penalty_repeat" },
|
||||||
|
{ LLM_KV_GENERAL_SAMPLING_MIROSTAT, "general.sampling.mirostat" },
|
||||||
|
{ LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU, "general.sampling.mirostat_tau" },
|
||||||
|
{ LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA, "general.sampling.mirostat_eta" },
|
||||||
{ LLM_KV_GENERAL_NAME, "general.name" },
|
{ LLM_KV_GENERAL_NAME, "general.name" },
|
||||||
{ LLM_KV_GENERAL_AUTHOR, "general.author" },
|
{ LLM_KV_GENERAL_AUTHOR, "general.author" },
|
||||||
{ LLM_KV_GENERAL_VERSION, "general.version" },
|
{ LLM_KV_GENERAL_VERSION, "general.version" },
|
||||||
|
|
|
||||||
|
|
@ -123,6 +123,18 @@ enum llm_kv {
|
||||||
LLM_KV_GENERAL_QUANTIZATION_VERSION,
|
LLM_KV_GENERAL_QUANTIZATION_VERSION,
|
||||||
LLM_KV_GENERAL_ALIGNMENT,
|
LLM_KV_GENERAL_ALIGNMENT,
|
||||||
LLM_KV_GENERAL_FILE_TYPE,
|
LLM_KV_GENERAL_FILE_TYPE,
|
||||||
|
LLM_KV_GENERAL_SAMPLING_SEQUENCE,
|
||||||
|
LLM_KV_GENERAL_SAMPLING_TOP_K,
|
||||||
|
LLM_KV_GENERAL_SAMPLING_TOP_P,
|
||||||
|
LLM_KV_GENERAL_SAMPLING_MIN_P,
|
||||||
|
LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY,
|
||||||
|
LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD,
|
||||||
|
LLM_KV_GENERAL_SAMPLING_TEMP,
|
||||||
|
LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N,
|
||||||
|
LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT,
|
||||||
|
LLM_KV_GENERAL_SAMPLING_MIROSTAT,
|
||||||
|
LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU,
|
||||||
|
LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA,
|
||||||
LLM_KV_GENERAL_NAME,
|
LLM_KV_GENERAL_NAME,
|
||||||
LLM_KV_GENERAL_AUTHOR,
|
LLM_KV_GENERAL_AUTHOR,
|
||||||
LLM_KV_GENERAL_VERSION,
|
LLM_KV_GENERAL_VERSION,
|
||||||
|
|
|
||||||
|
|
@ -7687,6 +7687,24 @@ int32_t llama_model_meta_count(const llama_model * model) {
|
||||||
return (int)model->gguf_kv.size();
|
return (int)model->gguf_kv.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Map a sampling metadata key identifier to its GGUF key name.
// Returns nullptr when `key` is not one of the known identifiers.
const char * llama_model_meta_key_str(llama_model_meta_key key) {
    // no `default:` label on purpose - unknown values fall through to the final return
    switch (key) {
        case LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE:        return "general.sampling.sequence";
        case LLAMA_MODEL_META_KEY_SAMPLING_TOP_K:           return "general.sampling.top_k";
        case LLAMA_MODEL_META_KEY_SAMPLING_TOP_P:           return "general.sampling.top_p";
        case LLAMA_MODEL_META_KEY_SAMPLING_MIN_P:           return "general.sampling.min_p";
        case LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY: return "general.sampling.xtc_probability";
        case LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD:   return "general.sampling.xtc_threshold";
        case LLAMA_MODEL_META_KEY_SAMPLING_TEMP:            return "general.sampling.temp";
        case LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N:  return "general.sampling.penalty_last_n";
        case LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT:  return "general.sampling.penalty_repeat";
        case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT:        return "general.sampling.mirostat";
        case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU:    return "general.sampling.mirostat_tau";
        case LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA:    return "general.sampling.mirostat_eta";
    }

    return nullptr;
}
|
||||||
|
|
||||||
int32_t llama_model_meta_key_by_index(const llama_model * model, int i, char * buf, size_t buf_size) {
|
int32_t llama_model_meta_key_by_index(const llama_model * model, int i, char * buf, size_t buf_size) {
|
||||||
if (i < 0 || i >= (int)model->gguf_kv.size()) {
|
if (i < 0 || i >= (int)model->gguf_kv.size()) {
|
||||||
if (buf_size > 0) {
|
if (buf_size > 0) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue