common : refactor common_sampler + grammar logic changes (#17937)

* common : refactor common_sampler + grammar logic changes

* tests : increase max_tokens to get needed response

* batched : fix uninitialized samplers
Georgi Gerganov 2025-12-14 10:11:13 +02:00 committed by GitHub
parent 3238b1400c
commit 254098a279
27 changed files with 372 additions and 293 deletions
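
For orientation before the per-file diffs: a minimal sketch of the init path after this refactor, using only the API that this commit introduces below (common_init_from_params now returns a common_init_result_ptr that owns the model, the context and one sampler per sequence). The decode step is elided; idx = -1 picks the logits of the last evaluated token.

#include "common.h"
#include "sampling.h"
#include "llama.h"

// sketch: typical caller after this commit (error logging is done inside common_init_from_params)
static int generate_one(common_params & params) {
    common_init_result_ptr llama_init = common_init_from_params(params);

    llama_model   * model = llama_init->model();
    llama_context * ctx   = llama_init->context();
    if (model == nullptr || ctx == nullptr) {
        return 1;
    }

    // samplers are now created inside common_init_result, one per sequence
    common_sampler * smpl = llama_init->sampler(0);

    // ... tokenize the prompt and llama_decode() it here ...

    const llama_token id = common_sampler_sample(smpl, ctx, /*idx =*/ -1);
    common_sampler_accept(smpl, id, /*accept_grammar =*/ true);

    return 0; // model, context and samplers are released when llama_init goes out of scope
}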

View File

@ -1415,7 +1415,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.sampling.top_k = value; params.sampling.top_k = value;
params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K; params.sampling.user_sampling_config |= common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K;
} }
).set_sparam()); ).set_sparam().set_env("LLAMA_ARG_TOP_K"));
add_opt(common_arg( add_opt(common_arg(
{"--top-p"}, "N", {"--top-p"}, "N",
string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p), string_format("top-p sampling (default: %.1f, 1.0 = disabled)", (double)params.sampling.top_p),

View File

@ -1013,31 +1013,40 @@ bool tty_can_use_colors() {
// Model utils // Model utils
// //
static inline void common_init_sampler_from_model( // TODO: move to common/sampling
static void common_init_sampler_from_model(
const llama_model * model, const llama_model * model,
common_params_sampling & sparams) { common_params_sampling & sparams) {
const uint64_t config = sparams.user_sampling_config; const uint64_t config = sparams.user_sampling_config;
auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) { auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
if (config & user_config) return; if (config & user_config) {
return;
}
char buf[64] = {0}; char buf[64] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr; char * end = nullptr;
int32_t v = strtol(buf, &end, 10); int32_t v = strtol(buf, &end, 10);
if (end && end != buf) dst = v; if (end && end != buf) {
dst = v;
}
} }
}; };
auto get_float = [&](const char * key, float & dst, uint64_t user_config) { auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
if (config & user_config) return; if (config & user_config) {
return;
}
char buf[128] = {0}; char buf[128] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr; char * end = nullptr;
float v = strtof(buf, &end); float v = strtof(buf, &end);
if (end && end != buf) dst = v; if (end && end != buf) {
dst = v;
}
} }
}; };
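
The lambdas above only write a sampling parameter when the matching bit in user_sampling_config is unset, so explicit CLI values always win over model metadata. A reduced sketch of the same parse-and-guard pattern; the metadata key and the user_set_temp flag here are illustrative, not the real names used by this commit:

#include <cstdlib>

#include "llama.h"

// sketch: read a float default from model metadata unless the user already set it
static void maybe_read_temp(const llama_model * model, float & temp, bool user_set_temp) {
    if (user_set_temp) {
        return; // explicit command-line value takes precedence
    }
    char buf[128] = {0};
    if (llama_model_meta_val_str(model, "example.sampling.temperature", buf, sizeof(buf)) > 0) {
        char * end = nullptr;
        const float v = strtof(buf, &end);
        if (end && end != buf) {
            temp = v; // only accept values that actually parsed
        }
    }
}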
@ -1065,31 +1074,122 @@ static inline void common_init_sampler_from_model(
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA); get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
} }
struct common_init_result common_init_from_params(common_params & params) { struct common_init_result::impl {
common_init_result iparams; impl() = default;
auto mparams = common_model_params_to_llama(params); ~impl() = default;
llama_model_ptr model;
llama_context_ptr context;
std::vector<llama_adapter_lora_ptr> lora;
std::vector<common_sampler_ptr> samplers;
};
common_init_result::common_init_result(common_params & params) :
pimpl(new impl{}) {
const auto mparams = common_model_params_to_llama(params);
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams); llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
if (model == NULL) { if (model == NULL) {
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n", return;
__func__, params.model.path.c_str());
return iparams;
} }
common_init_sampler_from_model(model, params.sampling); pimpl->model.reset(model);
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);
// updates params.sampling
// TODO: fix naming
common_init_sampler_from_model(model, params.sampling);
auto cparams = common_context_params_to_llama(params); auto cparams = common_context_params_to_llama(params);
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
params.sampling.ignore_eos = false;
}
// initialize once
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
}
}
if (params.sampling.ignore_eos) {
// add EOG biases to the active set of logit biases
params.sampling.logit_bias.insert(
params.sampling.logit_bias.end(),
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
}
//if (params.sampling.penalty_last_n == -1) {
// LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
// params.sampling.penalty_last_n = llama_n_ctx(lctx);
//}
//if (params.sampling.dry_penalty_last_n == -1) {
// LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
// params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
//}
pimpl->samplers.resize(cparams.n_seq_max);
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
}
llama_context * lctx = llama_init_from_model(model, cparams); llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) { if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n", LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str()); __func__, params.model.path.c_str());
llama_model_free(model); return;
return iparams;
} }
pimpl->context.reset(lctx);
}
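
The constructor above sizes the sampler array to cparams.n_seq_max, so parallel-decoding loops can give every stream its own sampler state. A hedged sketch of that pattern; sample_streams and i_batch are illustrative names, and the logit index per stream comes from however the caller built its batch:

#include <vector>

#include "common.h"
#include "sampling.h"
#include "llama.h"

// sketch: sample one token per parallel stream using the per-sequence samplers
static std::vector<llama_token> sample_streams(common_init_result & init,
                                               llama_context * ctx,
                                               const std::vector<int32_t> & i_batch) {
    std::vector<llama_token> out;
    for (size_t s = 0; s < i_batch.size(); ++s) {
        common_sampler * smpl = init.sampler((llama_seq_id) s);

        const llama_token id = common_sampler_sample(smpl, ctx, i_batch[s]);
        common_sampler_accept(smpl, id, /*accept_grammar =*/ true);

        out.push_back(id);
    }
    return out;
}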
llama_model * common_init_result::model() {
return pimpl->model.get();
}
llama_context * common_init_result::context() {
return pimpl->context.get();
}
common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
return pimpl->samplers[seq_id].get();
}
std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
return pimpl->lora;
}
void common_init_result::free_context() {
pimpl->context.reset();
}
common_init_result_ptr common_init_from_params(common_params & params) {
common_init_result_ptr res(new common_init_result(params));
llama_model * model = res->model();
if (model == NULL) {
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
return res;
}
llama_context * lctx = res->context();
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
return res;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) { if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__); LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
params.ctx_shift = false; params.ctx_shift = false;
@ -1101,10 +1201,7 @@ struct common_init_result common_init_from_params(common_params & params) {
const auto cvec = common_control_vector_load(params.control_vectors); const auto cvec = common_control_vector_load(params.control_vectors);
if (cvec.n_embd == -1) { if (cvec.n_embd == -1) {
llama_free(lctx); return res;
llama_model_free(model);
return iparams;
} }
int err = llama_apply_adapter_cvec( int err = llama_apply_adapter_cvec(
@ -1115,10 +1212,7 @@ struct common_init_result common_init_from_params(common_params & params) {
params.control_vector_layer_start, params.control_vector_layer_start,
params.control_vector_layer_end); params.control_vector_layer_end);
if (err) { if (err) {
llama_free(lctx); return res;
llama_model_free(model);
return iparams;
} }
} }
@ -1142,10 +1236,7 @@ struct common_init_result common_init_from_params(common_params & params) {
} }
if (!ok) { if (!ok) {
llama_free(lctx); return res;
llama_model_free(model);
return iparams;
} }
} }
@ -1155,9 +1246,7 @@ struct common_init_result common_init_from_params(common_params & params) {
lora.reset(llama_adapter_lora_init(model, la.path.c_str())); lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
if (lora == nullptr) { if (lora == nullptr) {
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
llama_free(lctx); return res;
llama_model_free(model);
return iparams;
} }
char buf[1024]; char buf[1024];
@ -1166,43 +1255,13 @@ struct common_init_result common_init_from_params(common_params & params) {
la.task_name = buf; la.task_name = buf;
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf)); llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
la.prompt_prefix = buf; la.prompt_prefix = buf;
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
} }
if (!params.lora_init_without_apply) { if (!params.lora_init_without_apply) {
common_set_adapter_lora(lctx, params.lora_adapters); common_set_adapter_lora(lctx, params.lora_adapters);
} }
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
params.sampling.ignore_eos = false;
}
// initialize once
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
}
}
if (params.sampling.ignore_eos) {
// add EOG biases to the active set of logit biases
params.sampling.logit_bias.insert(
params.sampling.logit_bias.end(),
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
}
if (params.sampling.penalty_last_n == -1) {
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
params.sampling.penalty_last_n = llama_n_ctx(lctx);
}
if (params.sampling.dry_penalty_last_n == -1) {
LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
}
if (params.warmup) { if (params.warmup) {
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
@ -1241,12 +1300,11 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_set_warmup(lctx, false); llama_set_warmup(lctx, false);
} }
iparams.model.reset(model); return res;
iparams.context.reset(lctx);
return iparams;
} }
common_init_result::~common_init_result() = default;
std::string get_model_endpoint() { std::string get_model_endpoint() {
const char * model_endpoint_env = getenv("MODEL_ENDPOINT"); const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility. // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
@ -1255,7 +1313,9 @@ std::string get_model_endpoint() {
std::string model_endpoint = "https://huggingface.co/"; std::string model_endpoint = "https://huggingface.co/";
if (endpoint_env) { if (endpoint_env) {
model_endpoint = endpoint_env; model_endpoint = endpoint_env;
if (model_endpoint.back() != '/') model_endpoint += '/'; if (model_endpoint.back() != '/') {
model_endpoint += '/';
}
} }
return model_endpoint; return model_endpoint;
} }

View File

@ -195,7 +195,6 @@ struct common_params_sampling {
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
std::vector<enum common_sampler_type> samplers = { std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_PENALTIES, COMMON_SAMPLER_TYPE_PENALTIES,
COMMON_SAMPLER_TYPE_DRY, COMMON_SAMPLER_TYPE_DRY,
@ -216,6 +215,10 @@ struct common_params_sampling {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply std::vector<llama_logit_bias> logit_bias; // logit biases to apply
std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
bool has_logit_bias() const {
return !logit_bias.empty();
}
// print the parameters into a string // print the parameters into a string
std::string print() const; std::string print() const;
}; };
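
has_logit_bias() is what lets common_sampler_init (further down in this diff) skip the logit-bias sampler entirely when nothing is configured. A hedged sketch of a caller filling in the biases, mirroring the ignore-eos handling moved into common_init_result above; apply_ignore_eos is an illustrative name:

#include <cmath>

#include "common.h"
#include "llama.h"

// sketch: bias all end-of-generation tokens to -inf when ignore_eos is requested
static void apply_ignore_eos(const llama_vocab * vocab, common_params_sampling & sparams) {
    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
        if (llama_vocab_is_eog(vocab, i)) {
            sparams.logit_bias_eog.push_back({ i, -INFINITY });
        }
    }
    if (sparams.ignore_eos) {
        sparams.logit_bias.insert(sparams.logit_bias.end(),
                sparams.logit_bias_eog.begin(), sparams.logit_bias_eog.end());
    }
    // common_sampler_init() only adds a logit-bias sampler when has_logit_bias() is true
}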
@ -669,15 +672,29 @@ bool tty_can_use_colors();
// Model utils // Model utils
// //
// note: defines object's lifetime struct common_sampler;
struct common_init_result {
llama_model_ptr model;
llama_context_ptr context;
std::vector<llama_adapter_lora_ptr> lora; // note: defines the model, context, samplers, etc. lifetimes
struct common_init_result {
common_init_result(common_params & params);
~common_init_result();
llama_model * model();
llama_context * context();
common_sampler * sampler(llama_seq_id seq_id);
std::vector<llama_adapter_lora_ptr> & lora();
void free_context();
private:
struct impl;
std::unique_ptr<impl> pimpl;
}; };
struct common_init_result common_init_from_params(common_params & params); using common_init_result_ptr = std::unique_ptr<common_init_result>;
common_init_result_ptr common_init_from_params(common_params & params);
struct llama_model_params common_model_params_to_llama ( common_params & params); struct llama_model_params common_model_params_to_llama ( common_params & params);
struct llama_context_params common_context_params_to_llama(const common_params & params); struct llama_context_params common_context_params_to_llama(const common_params & params);
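
Because common_init_result_ptr is a std::unique_ptr, long-lived components can store it as a member to tie the model/context lifetime to their own, which is what the server and mtmd-cli changes below do. A minimal sketch with a hypothetical wrapper type:

#include "common.h"

// hypothetical wrapper that keeps the model, context and samplers alive for its own lifetime
struct engine {
    common_init_result_ptr init;

    explicit engine(common_params & params) : init(common_init_from_params(params)) {}

    bool ok() const {
        return init->model() != nullptr && init->context() != nullptr;
    }
    // everything owned by `init` is released automatically when `engine` is destroyed
};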

View File

@ -104,9 +104,10 @@ struct ring_buffer {
struct common_sampler { struct common_sampler {
common_params_sampling params; common_params_sampling params;
struct llama_sampler * grmr;
struct llama_sampler * chain; struct llama_sampler * chain;
bool grammar;
ring_buffer<llama_token> prev; ring_buffer<llama_token> prev;
std::vector<llama_token_data> cur; std::vector<llama_token_data> cur;
@ -116,7 +117,6 @@ struct common_sampler {
void reset() { void reset() {
prev.clear(); prev.clear();
llama_sampler_reset(grmr);
llama_sampler_reset(chain); llama_sampler_reset(chain);
} }
@ -167,10 +167,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
lparams.no_perf = params.no_perf; lparams.no_perf = params.no_perf;
struct llama_sampler * grmr; llama_sampler * chain = llama_sampler_chain_init(lparams);
bool grammar = false;
std::vector<llama_sampler *> samplers;
if (params.grammar.compare(0, 11, "%llguidance") == 0) { if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE #ifdef LLAMA_USE_LLGUIDANCE
grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()); samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
grammar = true;
#else #else
GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled"); GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE #endif // LLAMA_USE_LLGUIDANCE
@ -217,30 +222,23 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
trigger_patterns_c.push_back(regex.c_str()); trigger_patterns_c.push_back(regex.c_str());
} }
grmr = params.grammar_lazy if (!params.grammar.empty()) {
? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root", if (params.grammar_lazy) {
samplers.push_back(
llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
trigger_patterns_c.data(), trigger_patterns_c.size(), trigger_patterns_c.data(), trigger_patterns_c.size(),
trigger_tokens.data(), trigger_tokens.size()) trigger_tokens.data(), trigger_tokens.size()));
: llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"); } else {
if (!grmr) { samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
return nullptr; }
grammar = true;
} }
} }
auto * result = new common_sampler { if (params.has_logit_bias()) {
/* .params = */ params, samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
/* .grmr = */ grmr, }
/* .chain = */ llama_sampler_chain_init(lparams),
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
/* .cur_p = */ {},
};
llama_sampler_chain_add(result->chain,
llama_sampler_init_logit_bias(
llama_vocab_n_tokens(vocab),
params.logit_bias.size(),
params.logit_bias.data()));
if (params.mirostat == 0) { if (params.mirostat == 0) {
for (const auto & cnstr : params.samplers) { for (const auto & cnstr : params.samplers) {
@ -253,58 +251,70 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
c_breakers.push_back(str.c_str()); c_breakers.push_back(str.c_str());
} }
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); samplers.push_back(llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
} }
break; break;
case COMMON_SAMPLER_TYPE_TOP_K: case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); samplers.push_back(llama_sampler_init_top_k (params.top_k));
break; break;
case COMMON_SAMPLER_TYPE_TOP_P: case COMMON_SAMPLER_TYPE_TOP_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); samplers.push_back(llama_sampler_init_top_p (params.top_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
llama_sampler_chain_add(result->chain, llama_sampler_init_top_n_sigma (params.top_n_sigma)); samplers.push_back(llama_sampler_init_top_n_sigma(params.top_n_sigma));
break; break;
case COMMON_SAMPLER_TYPE_MIN_P: case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); samplers.push_back(llama_sampler_init_min_p (params.min_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_XTC: case COMMON_SAMPLER_TYPE_XTC:
llama_sampler_chain_add(result->chain, llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed)); samplers.push_back(llama_sampler_init_xtc (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
break; break;
case COMMON_SAMPLER_TYPE_TYPICAL_P: case COMMON_SAMPLER_TYPE_TYPICAL_P:
llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); samplers.push_back(llama_sampler_init_typical (params.typ_p, params.min_keep));
break; break;
case COMMON_SAMPLER_TYPE_TEMPERATURE: case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); samplers.push_back(llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break; break;
case COMMON_SAMPLER_TYPE_INFILL: case COMMON_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab)); samplers.push_back(llama_sampler_init_infill (vocab));
break; break;
case COMMON_SAMPLER_TYPE_PENALTIES: case COMMON_SAMPLER_TYPE_PENALTIES:
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); samplers.push_back(llama_sampler_init_penalties (params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
break; break;
default: default:
GGML_ASSERT(false && "unknown sampler type"); GGML_ASSERT(false && "unknown sampler type");
} }
} }
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
samplers.push_back(llama_sampler_init_dist(params.seed));
} else if (params.mirostat == 1) { } else if (params.mirostat == 1) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); samplers.push_back(llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); samplers.push_back(llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
} else if (params.mirostat == 2) { } else if (params.mirostat == 2) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); samplers.push_back(llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); samplers.push_back(llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
} else { } else {
GGML_ASSERT(false && "unknown mirostat version"); GGML_ASSERT(false && "unknown mirostat version");
} }
for (auto * smpl : samplers) {
llama_sampler_chain_add(chain, smpl);
}
auto * result = new common_sampler {
/* .params = */ params,
/* .chain = */ chain,
/* .grammar = */ grammar,
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
/* .cur_p = */ {},
};
return result; return result;
} }
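
Net effect of this file: there is no separate grmr sampler anymore; when a grammar is configured it becomes the first entry of the single chain, followed by the optional logit-bias sampler, the configured samplers and the final dist (or mirostat) sampler. A hedged sketch of the equivalent chain layout for the default mirostat == 0 case, with hard-coded example parameters:

#include "llama.h"

// sketch: the order in which common_sampler_init now assembles the chain
// (grammar and logit-bias entries are optional; parameter values are examples only)
static llama_sampler * build_chain_sketch(const llama_vocab * vocab, const char * grammar_str) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    if (grammar_str && *grammar_str) {
        // the grammar sampler, when present, is always the first one in the chain
        llama_sampler_chain_add(chain, llama_sampler_init_grammar(vocab, grammar_str, "root"));
    }
    // ... llama_sampler_init_logit_bias(...) would go here if any biases are set ...

    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_top_p(0.95f, 1));
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    return chain;
}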
void common_sampler_free(struct common_sampler * gsmpl) { void common_sampler_free(struct common_sampler * gsmpl) {
if (gsmpl) { if (gsmpl) {
llama_sampler_free(gsmpl->grmr);
llama_sampler_free(gsmpl->chain); llama_sampler_free(gsmpl->chain);
delete gsmpl; delete gsmpl;
@ -314,11 +324,24 @@ void common_sampler_free(struct common_sampler * gsmpl) {
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) { void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
const auto tm = gsmpl->tm(); const auto tm = gsmpl->tm();
if (accept_grammar) { if (gsmpl->grammar) {
llama_sampler_accept(gsmpl->grmr, token); const int n_smpl = llama_sampler_chain_n(gsmpl->chain);
}
for (int i = 0; i < n_smpl; i++) {
auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
// the grammar sampler is always the first one
if (i == 0) {
if (accept_grammar) {
llama_sampler_accept(smpl, token);
}
} else {
llama_sampler_accept(smpl, token);
}
}
} else {
llama_sampler_accept(gsmpl->chain, token); llama_sampler_accept(gsmpl->chain, token);
}
gsmpl->prev.push_back(token); gsmpl->prev.push_back(token);
} }
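
With the grammar sampler folded into the chain, accept_grammar now simply decides whether chain entry 0 (the grammar, when present) is advanced; every other sampler always sees the token. Typical usage, matching how the server code below feeds prompt tokens versus generated tokens:

#include <vector>

#include "sampling.h"
#include "llama.h"

// sketch: prompt tokens go through the chain without advancing the grammar,
// generated tokens advance everything, including the grammar sampler at index 0
static void prime_and_accept(common_sampler * smpl,
                             const std::vector<llama_token> & prompt,
                             llama_token generated) {
    for (const llama_token id : prompt) {
        common_sampler_accept(smpl, id, /*accept_grammar =*/ false);
    }
    common_sampler_accept(smpl, generated, /*accept_grammar =*/ true);
}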
@ -330,8 +353,8 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) { struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
return new common_sampler { return new common_sampler {
/* .params = */ gsmpl->params, /* .params = */ gsmpl->params,
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .chain = */ llama_sampler_clone(gsmpl->chain), /* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .grammar = */ gsmpl->grammar,
/* .prev = */ gsmpl->prev, /* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur, /* .cur = */ gsmpl->cur,
/* .cur_p = */ gsmpl->cur_p, /* .cur_p = */ gsmpl->cur_p,
@ -383,58 +406,33 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
} }
} }
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
return gsmpl->chain;
}
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
llama_synchronize(ctx); llama_synchronize(ctx);
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
const auto tm = gsmpl->tm(); const auto tm = gsmpl->tm();
gsmpl->set_logits(ctx, idx); llama_token id = LLAMA_TOKEN_NULL;
auto & grmr = gsmpl->grmr;
auto & chain = gsmpl->chain; auto & chain = gsmpl->chain;
auto & cur_p = gsmpl->cur_p; // initialized by set_logits auto & cur_p = gsmpl->cur_p; // initialized by set_logits
if (grammar_first) { gsmpl->set_logits(ctx, idx);
llama_sampler_apply(grmr, &cur_p);
}
llama_sampler_apply(chain, &cur_p); llama_sampler_apply(chain, &cur_p);
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration"); GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
const llama_token id = cur_p.data[cur_p.selected].id; id = cur_p.data[cur_p.selected].id;
if (grammar_first) {
return id; return id;
} }
// check if it the sampled token fits the grammar std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
{
llama_token_data single_token_data = { id, 1.0f, 0.0f };
llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
llama_sampler_apply(grmr, &single_token_data_array);
const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
if (is_valid) {
return id;
}
}
// resampling:
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
gsmpl->set_logits(ctx, idx);
llama_sampler_apply(grmr, &cur_p);
llama_sampler_apply(chain, &cur_p);
GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");
return cur_p.data[cur_p.selected].id;
}
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
std::vector<llama_token> result; std::vector<llama_token> result;
@ -442,7 +440,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
size_t i = 0; size_t i = 0;
for (; i < draft.size(); i++) { for (; i < draft.size(); i++) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
common_sampler_accept(gsmpl, id, true); common_sampler_accept(gsmpl, id, true);
@ -454,7 +452,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
} }
if (i == draft.size()) { if (i == draft.size()) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
common_sampler_accept(gsmpl, id, true); common_sampler_accept(gsmpl, id, true);
@ -464,13 +462,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
return result; return result;
} }
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) { std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
std::vector<int> idxs(draft.size() + 1); std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) { for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i; idxs[i] = i;
} }
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first); return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
} }
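
The grammar_first flag and the resample-on-grammar-violation path are gone: the chain (grammar included) is applied once per position. Speculative callers just hand over the draft, as sketched below; verify_draft is an illustrative helper and assumes the draft tokens were decoded so that their logits sit at batch indices 0..draft.size():

#include <vector>

#include "sampling.h"
#include "llama.h"

// sketch: verify a drafted continuation against the target sampler with the simplified API
static size_t verify_draft(common_sampler * smpl, llama_context * ctx,
                           const std::vector<llama_token> & draft) {
    // returns at least 1 token and stops at the first mismatch with the draft
    const std::vector<llama_token> ids = common_sampler_sample_and_accept_n(smpl, ctx, draft);
    return ids.size() - 1; // number of draft tokens that were accepted
}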
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
@ -515,7 +513,8 @@ std::string common_sampler_print(const struct common_sampler * gsmpl) {
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) { for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
result += std::string("-> ") + llama_sampler_name(smpl) + " "; result += std::string("-> ");
result += std::string(llama_sampler_name(smpl)) + " ";
} }
return result; return result;

View File

@ -48,6 +48,8 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
// arguments can be nullptr to skip printing // arguments can be nullptr to skip printing
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl); void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
// extended sampling implementation: // extended sampling implementation:
// //
// - set logits // - set logits
@ -55,10 +57,7 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
// - check if the token fits the grammar (if any) // - check if the token fits the grammar (if any)
// - if not: resample by first applying the grammar constraints and then sampling again (slower path) // - if not: resample by first applying the grammar constraints and then sampling again (slower path)
// //
// if grammar_first is true, the grammar is applied before the samplers (slower) llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
//
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
// generalized version of common_sampler_sample // generalized version of common_sampler_sample
// //
@ -76,10 +75,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
// //
// returns at least 1 token, up to idxs.size() // returns at least 1 token, up to idxs.size()
// //
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false); std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft);
// assume idxs == [ 0, 1, 2, ..., draft.size() ] // assume idxs == [ 0, 1, 2, ..., draft.size() ]
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false); std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft);
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
@ -107,3 +106,9 @@ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std:
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
const char * grammar_kind, const char * grammar_data); const char * grammar_kind, const char * grammar_data);
struct common_sampler_deleter {
void operator()(common_sampler * s) { common_sampler_free(s); }
};
typedef std::unique_ptr<common_sampler, common_sampler_deleter> common_sampler_ptr;
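
common_sampler_ptr mirrors the existing llama_*_ptr helpers and is what the server slots switch to below. A minimal usage sketch; make_slot_sampler is an illustrative name:

#include "common.h"
#include "sampling.h"
#include "llama.h"

// sketch: RAII ownership of a sampler, as the server slots now do
static void make_slot_sampler(const llama_model * model, const common_params_sampling & sparams) {
    common_sampler_ptr smpl(common_sampler_init(model, sparams));
    if (!smpl) {
        return; // e.g. invalid grammar
    }
    // ... use smpl.get() with common_sampler_sample() / common_sampler_accept() ...
}   // common_sampler_free() is called automatically here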

View File

@ -315,7 +315,7 @@ llama_tokens common_speculative_gen_draft(
for (int i = 0; i < params.n_draft; ++i) { for (int i = 0; i < params.n_draft; ++i) {
common_batch_clear(batch); common_batch_clear(batch);
common_sampler_sample(smpl, ctx_dft, 0, true); common_sampler_sample(smpl, ctx_dft, 0);
const auto * cur_p = common_sampler_get_candidates(smpl, true); const auto * cur_p = common_sampler_get_candidates(smpl, true);

View File

@ -2,6 +2,7 @@
#include "common.h" #include "common.h"
#include "log.h" #include "log.h"
#include "llama.h" #include "llama.h"
#include "sampling.h"
#include <algorithm> #include <algorithm>
#include <cstdio> #include <cstdio>
@ -64,11 +65,12 @@ int main(int argc, char ** argv) {
ctx_params.n_ctx = n_kv_req; ctx_params.n_ctx = n_kv_req;
ctx_params.n_batch = std::max(n_predict, n_parallel); ctx_params.n_batch = std::max(n_predict, n_parallel);
llama_context * ctx = llama_init_from_model(model, ctx_params);
auto sparams = llama_sampler_chain_default_params(); auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false; sparams.no_perf = false;
std::vector<llama_sampler *> samplers;
for (int32_t i = 0; i < n_parallel; ++i) {
llama_sampler * smpl = llama_sampler_chain_init(sparams); llama_sampler * smpl = llama_sampler_chain_init(sparams);
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k)); llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
@ -76,6 +78,11 @@ int main(int argc, char ** argv) {
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp)); llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed)); llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
samplers.push_back(smpl);
}
llama_context * ctx = llama_init_from_model(model, ctx_params);
if (ctx == NULL) { if (ctx == NULL) {
LOG_ERR("%s: error: failed to create the llama_context\n" , __func__); LOG_ERR("%s: error: failed to create the llama_context\n" , __func__);
return 1; return 1;
@ -173,7 +180,7 @@ int main(int argc, char ** argv) {
continue; continue;
} }
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]); const llama_token new_token_id = llama_sampler_sample(samplers[i], ctx, i_batch[i]);
// is it an end of generation? -> mark the stream as finished // is it an end of generation? -> mark the stream as finished
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) { if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {
@ -229,14 +236,17 @@ int main(int argc, char ** argv) {
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
LOG("\n"); LOG("\n");
llama_perf_sampler_print(smpl); llama_perf_sampler_print(samplers[0]);
llama_perf_context_print(ctx); llama_perf_context_print(ctx);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
llama_batch_free(batch); llama_batch_free(batch);
llama_sampler_free(smpl); for (auto & sampler_config : samplers) {
llama_sampler_free(sampler_config);
}
llama_free(ctx); llama_free(ctx);
llama_model_free(model); llama_model_free(model);
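
The batched example fix amounts to building one independent sampler chain per parallel sequence before decoding starts, so no stream ever samples through an uninitialized chain. A condensed sketch of that pattern with example parameters (the real example takes them from params.sampling):

#include <vector>

#include "llama.h"

// sketch: one sampler chain per parallel stream, created up front
static std::vector<llama_sampler *> make_stream_samplers(int32_t n_parallel, uint32_t seed) {
    auto sparams = llama_sampler_chain_default_params();

    std::vector<llama_sampler *> samplers;
    samplers.reserve((size_t) n_parallel);
    for (int32_t i = 0; i < n_parallel; ++i) {
        llama_sampler * smpl = llama_sampler_chain_init(sparams);
        llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));
        llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.95f, 1));
        llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
        llama_sampler_chain_add(smpl, llama_sampler_init_dist(seed));
        samplers.push_back(smpl);
    }
    return samplers; // caller frees each with llama_sampler_free()
}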

View File

@ -131,10 +131,10 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa); llama_numa_init(params.numa);
// load the model // load the model
common_init_result llama_init = common_init_from_params(params); auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get(); auto * model = llama_init->model();
llama_context * ctx = llama_init.context.get(); auto * ctx = llama_init->context();
if (model == NULL) { if (model == NULL) {
LOG_ERR("%s: unable to load model\n", __func__); LOG_ERR("%s: unable to load model\n", __func__);

View File

@ -202,10 +202,10 @@ int main(int argc, char ** argv) {
params.warmup = false; params.warmup = false;
// init // init
common_init_result llama_init = common_init_from_params(params); auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get(); auto * model = llama_init->model();
llama_context * ctx = llama_init.context.get(); auto * ctx = llama_init->context();
if (model == nullptr || ctx == nullptr) { if (model == nullptr || ctx == nullptr) {
LOG_ERR("%s : failed to init\n", __func__); LOG_ERR("%s : failed to init\n", __func__);

View File

@ -55,10 +55,10 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa); llama_numa_init(params.numa);
// load the target model // load the target model
common_init_result llama_init = common_init_from_params(params); auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get(); auto * model = llama_init->model();
llama_context * ctx = llama_init.context.get(); auto * ctx = llama_init->context();
auto * mem = llama_get_memory(ctx); auto * mem = llama_get_memory(ctx);

View File

@ -18,16 +18,16 @@ int main(int argc, char ** argv){
llama_numa_init(params.numa); llama_numa_init(params.numa);
// load the model // load the model
common_init_result llama_init = common_init_from_params(params); auto llama_init = common_init_from_params(params);
llama_model_ptr & model = llama_init.model; auto * model = llama_init->model();
llama_context_ptr & ctx = llama_init.context; auto * ctx = llama_init->context();
GGML_ASSERT(model != nullptr); GGML_ASSERT(model != nullptr);
// tokenize the prompt // tokenize the prompt
std::vector<llama_token> inp; std::vector<llama_token> inp;
inp = common_tokenize(ctx.get(), params.prompt, true, true); inp = common_tokenize(ctx, params.prompt, true, true);
fprintf(stderr, "%s: tokenization done\n", __func__); fprintf(stderr, "%s: tokenization done\n", __func__);
common_ngram_cache ngram_cache; common_ngram_cache ngram_cache;

View File

@ -28,13 +28,13 @@ int main(int argc, char ** argv){
llama_numa_init(params.numa); llama_numa_init(params.numa);
// load the model // load the model
common_init_result llama_init = common_init_from_params(params); auto llama_init = common_init_from_params(params);
llama_context_ptr & ctx = llama_init.context; llama_context * ctx = llama_init->context();
// tokenize the prompt // tokenize the prompt
std::vector<llama_token> inp; std::vector<llama_token> inp;
inp = common_tokenize(ctx.get(), params.prompt, true, true); inp = common_tokenize(ctx, params.prompt, true, true);
common_ngram_cache ngram_cache_context; common_ngram_cache ngram_cache_context;
common_ngram_cache ngram_cache_dynamic; common_ngram_cache ngram_cache_dynamic;
@ -65,7 +65,7 @@ int main(int argc, char ** argv){
} }
const int n_input = inp.size(); const int n_input = inp.size();
const int n_ctx = llama_n_ctx(ctx.get()); const int n_ctx = llama_n_ctx(ctx);
int n_drafted = 0; int n_drafted = 0;
int n_accept = 0; int n_accept = 0;

View File

@ -29,10 +29,10 @@ int main(int argc, char ** argv){
llama_numa_init(params.numa); llama_numa_init(params.numa);
// load the model // load the model
common_init_result llama_init = common_init_from_params(params); auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get(); auto * model = llama_init->model();
llama_context * ctx = llama_init.context.get(); auto * ctx = llama_init->context();
const llama_vocab * vocab = llama_model_get_vocab(model); const llama_vocab * vocab = llama_model_get_vocab(model);

View File

@ -192,10 +192,10 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa); llama_numa_init(params.numa);
// load the target model // load the target model
common_init_result llama_init = common_init_from_params(params); auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get(); auto * model = llama_init->model();
llama_context * ctx = llama_init.context.get(); auto * ctx = llama_init->context();
auto * mem = llama_get_memory(ctx); auto * mem = llama_get_memory(ctx);

View File

@ -149,10 +149,10 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa); llama_numa_init(params.numa);
// load the model // load the model
common_init_result llama_init = common_init_from_params(params); auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get(); auto * model = llama_init->model();
llama_context * ctx = llama_init.context.get(); auto * ctx = llama_init->context();
if (model == NULL) { if (model == NULL) {
LOG_ERR("%s: unable to load model\n", __func__); LOG_ERR("%s: unable to load model\n", __func__);

View File

@ -34,10 +34,10 @@ int main(int argc, char ** argv) {
std::string result2; std::string result2;
// init // init
common_init_result llama_init = common_init_from_params(params); auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get(); auto * model = llama_init->model();
llama_context * ctx = llama_init.context.get(); auto * ctx = llama_init->context();
if (model == nullptr || ctx == nullptr) { if (model == nullptr || ctx == nullptr) {
fprintf(stderr, "%s : failed to init\n", __func__); fprintf(stderr, "%s : failed to init\n", __func__);

View File

@ -40,10 +40,10 @@ int main(int argc, char ** argv) {
llama_context * ctx_dft = NULL; llama_context * ctx_dft = NULL;
// load the target model // load the target model
common_init_result llama_init_tgt = common_init_from_params(params); auto llama_init_tgt = common_init_from_params(params);
model_tgt = llama_init_tgt.model.get(); model_tgt = llama_init_tgt->model();
ctx_tgt = llama_init_tgt.context.get(); ctx_tgt = llama_init_tgt->context();
const llama_vocab * vocab = llama_model_get_vocab(model_tgt); const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
@ -61,10 +61,10 @@ int main(int argc, char ** argv) {
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides; params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
common_init_result llama_init_dft = common_init_from_params(params); auto llama_init_dft = common_init_from_params(params);
//model_dft = llama_init_dft.model.get(); //model_dft = llama_init_dft->model();
ctx_dft = llama_init_dft.context.get(); ctx_dft = llama_init_dft->context();
if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) { if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str()); LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str());

View File

@ -71,10 +71,10 @@ int main(int argc, char ** argv) {
llama_context * ctx_dft = NULL; llama_context * ctx_dft = NULL;
// load the target model // load the target model
common_init_result llama_init_tgt = common_init_from_params(params); auto llama_init_tgt = common_init_from_params(params);
model_tgt = llama_init_tgt.model.get(); model_tgt = llama_init_tgt->model();
ctx_tgt = llama_init_tgt.context.get(); ctx_tgt = llama_init_tgt->context();
// load the draft model // load the draft model
params.devices = params.speculative.devices; params.devices = params.speculative.devices;
@ -87,10 +87,10 @@ int main(int argc, char ** argv) {
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides; params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
common_init_result llama_init_dft = common_init_from_params(params); auto llama_init_dft = common_init_from_params(params);
model_dft = llama_init_dft.model.get(); model_dft = llama_init_dft->model();
ctx_dft = llama_init_dft.context.get(); ctx_dft = llama_init_dft->context();
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
@ -242,7 +242,7 @@ int main(int argc, char ** argv) {
bool accept = false; bool accept = false;
if (params.sampling.temp > 0) { if (params.sampling.temp > 0) {
// stochastic verification // stochastic verification
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true); common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
auto & dist_tgt = *common_sampler_get_candidates(smpl, true); auto & dist_tgt = *common_sampler_get_candidates(smpl, true);
@ -491,7 +491,7 @@ int main(int argc, char ** argv) {
continue; continue;
} }
common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true); common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft);
const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true); const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);

View File

@ -39,9 +39,10 @@ int main(int argc, char ** argv) {
llama_backend_init(); llama_backend_init();
llama_numa_init(params.numa); llama_numa_init(params.numa);
// load the model and apply lora adapter, if any // load the model and apply lora adapter, if any
common_init_result llama_init = common_init_from_params(params); auto llama_init = common_init_from_params(params);
llama_model_ptr & model = llama_init.model;
llama_context_ptr & ctx = llama_init.context; auto * model = llama_init->model();
auto * ctx = llama_init->context();
if (model == NULL) { if (model == NULL) {
LOG_ERR("%s: unable to load model\n", __func__); LOG_ERR("%s: unable to load model\n", __func__);
@ -54,8 +55,8 @@ int main(int argc, char ** argv) {
LOG_INF("%s\n", common_params_get_system_info(params).c_str()); LOG_INF("%s\n", common_params_get_system_info(params).c_str());
} }
std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true); std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);
ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get()) / 2); ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx, tokens, llama_n_ctx(ctx) / 2);
struct lr_opt & lr = params.lr; struct lr_opt & lr = params.lr;
LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %d -period %.2g -val %.2g\n", LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %d -period %.2g -val %.2g\n",
@ -70,7 +71,7 @@ int main(int argc, char ** argv) {
/*get_opt_pars_ud =*/&params.lr, /*get_opt_pars_ud =*/&params.lr,
/*optimizer_type =*/params.optimizer, /*optimizer_type =*/params.optimizer,
}; };
llama_opt_init(ctx.get(), model.get(), lopt_params); llama_opt_init(ctx, model, lopt_params);
const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split); const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);
@ -78,7 +79,7 @@ int main(int argc, char ** argv) {
ggml_opt_result_t result_eval = ggml_opt_result_init(); ggml_opt_result_t result_eval = ggml_opt_result_init();
for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) { for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) {
llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split, llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar); ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
@ -88,7 +89,7 @@ int main(int argc, char ** argv) {
ggml_opt_result_free(result_train); ggml_opt_result_free(result_train);
ggml_opt_result_free(result_eval); ggml_opt_result_free(result_eval);
llama_model_save_to_file(model.get(), params.out_file.c_str()); llama_model_save_to_file(model, params.out_file.c_str());
llama_backend_free(); llama_backend_free();

View File

@ -141,13 +141,15 @@ int main(int argc, char ** argv) {
// load the model and apply lora adapter, if any // load the model and apply lora adapter, if any
LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__); LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
common_init_result llama_init = common_init_from_params(params);
model = llama_init.model.get(); auto llama_init = common_init_from_params(params);
ctx = llama_init.context.get();
if (model == NULL) { ctx = llama_init->context();
LOG_ERR("%s: error: unable to load model\n", __func__); model = llama_init->model();
smpl = llama_init->sampler(0);
if (ctx == NULL) {
LOG_ERR("%s: error: unable to create context\n", __func__);
return 1; return 1;
} }
@ -474,12 +476,6 @@ int main(int argc, char ** argv) {
} }
} }
smpl = common_sampler_init(model, sparams);
if (!smpl) {
LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
return 1;
}
LOG_INF("sampler seed: %u\n", common_sampler_get_seed(smpl)); LOG_INF("sampler seed: %u\n", common_sampler_get_seed(smpl));
LOG_INF("sampler params: \n%s\n", sparams.print().c_str()); LOG_INF("sampler params: \n%s\n", sparams.print().c_str());
LOG_INF("sampler chain: %s\n", common_sampler_print(smpl).c_str()); LOG_INF("sampler chain: %s\n", common_sampler_print(smpl).c_str());
@ -993,8 +989,6 @@ int main(int argc, char ** argv) {
LOG("\n\n"); LOG("\n\n");
common_perf_print(ctx, smpl); common_perf_print(ctx, smpl);
common_sampler_free(smpl);
llama_backend_free(); llama_backend_free();
ggml_threadpool_free_fn(threadpool); ggml_threadpool_free_fn(threadpool);

View File

@ -419,10 +419,10 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa); llama_numa_init(params.numa);
// load the model to get hparams // load the model to get hparams
common_init_result llama_init = common_init_from_params(params); auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get(); auto * model = llama_init->model();
llama_context * ctx = llama_init.context.get(); auto * ctx = llama_init->context();
// int n_ctx = llama_n_ctx(ctx); // int n_ctx = llama_n_ctx(ctx);
int n_layers = llama_model_n_layer(model); int n_layers = llama_model_n_layer(model);

View File

@ -1265,10 +1265,10 @@ int main(int argc, char ** argv) {
params.warmup = false; params.warmup = false;
// init // init
common_init_result llama_init = common_init_from_params(params); auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get(); auto * model = llama_init->model();
llama_context * ctx = llama_init.context.get(); auto * ctx = llama_init->context();
if (model == nullptr || ctx == nullptr) { if (model == nullptr || ctx == nullptr) {
LOG_ERR("%s : failed to init\n", __func__); LOG_ERR("%s : failed to init\n", __func__);

View File

@ -65,7 +65,7 @@ static void sigint_handler(int signo) {
struct mtmd_cli_context { struct mtmd_cli_context {
mtmd::context_ptr ctx_vision; mtmd::context_ptr ctx_vision;
common_init_result llama_init; common_init_result_ptr llama_init;
llama_model * model; llama_model * model;
llama_context * lctx; llama_context * lctx;
@ -89,8 +89,8 @@ struct mtmd_cli_context {
llama_pos n_past = 0; llama_pos n_past = 0;
mtmd_cli_context(common_params & params) : llama_init(common_init_from_params(params)) { mtmd_cli_context(common_params & params) : llama_init(common_init_from_params(params)) {
model = llama_init.model.get(); model = llama_init->model();
lctx = llama_init.context.get(); lctx = llama_init->context();
vocab = llama_model_get_vocab(model); vocab = llama_model_get_vocab(model);
smpl = common_sampler_init(model, params.sampling); smpl = common_sampler_init(model, params.sampling);
n_threads = params.cpuparams.n_threads; n_threads = params.cpuparams.n_threads;

View File

@ -2024,10 +2024,10 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa); llama_numa_init(params.numa);
// load the model and apply lora adapter, if any // load the model and apply lora adapter, if any
common_init_result llama_init = common_init_from_params(params); auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get(); auto * model = llama_init->model();
llama_context * ctx = llama_init.context.get(); auto * ctx = llama_init->context();
if (model == NULL) { if (model == NULL) {
LOG_ERR("%s: unable to load model\n", __func__); LOG_ERR("%s: unable to load model\n", __func__);

View File

@ -153,7 +153,7 @@ struct server_slot {
// sampling // sampling
json json_schema; json json_schema;
struct common_sampler * smpl = nullptr; common_sampler_ptr smpl;
llama_token sampled; // in speculative mode, this is the last accepted token llama_token sampled; // in speculative mode, this is the last accepted token
llama_tokens drafted; llama_tokens drafted;
@ -510,8 +510,8 @@ struct server_context_impl {
common_params params_base; common_params params_base;
// note: keep these alive - they determine the lifetime of the model, context, etc. // note: keep these alive - they determine the lifetime of the model, context, etc.
common_init_result llama_init; common_init_result_ptr llama_init;
common_init_result llama_init_dft; common_init_result_ptr llama_init_dft;
llama_model * model = nullptr; llama_model * model = nullptr;
llama_context * ctx = nullptr; llama_context * ctx = nullptr;
@ -557,9 +557,6 @@ struct server_context_impl {
// Clear any sampling context // Clear any sampling context
for (server_slot & slot : slots) { for (server_slot & slot : slots) {
common_sampler_free(slot.smpl);
slot.smpl = nullptr;
llama_free(slot.ctx_dft); llama_free(slot.ctx_dft);
slot.ctx_dft = nullptr; slot.ctx_dft = nullptr;
@ -580,8 +577,8 @@ struct server_context_impl {
llama_init = common_init_from_params(params_base); llama_init = common_init_from_params(params_base);
model = llama_init.model.get(); model = llama_init->model();
ctx = llama_init.context.get(); ctx = llama_init->context();
if (model == nullptr) { if (model == nullptr) {
SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
@ -613,25 +610,25 @@ struct server_context_impl {
llama_init_dft = common_init_from_params(params_dft); llama_init_dft = common_init_from_params(params_dft);
model_dft = llama_init_dft.model.get(); model_dft = llama_init_dft->model();
if (model_dft == nullptr) { if (model_dft == nullptr) {
SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str()); SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str());
return false; return false;
} }
vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get()); vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft->context());
if (!vocab_dft_compatible) { if (!vocab_dft_compatible) {
SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
} }
const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); const int n_ctx_dft = llama_n_ctx(llama_init_dft->context());
cparams_dft = common_context_params_to_llama(params_dft); cparams_dft = common_context_params_to_llama(params_dft);
cparams_dft.n_batch = n_ctx_dft; cparams_dft.n_batch = n_ctx_dft;
// the context is not needed - we will create one for each slot // the context is not needed - we will create one for each slot
llama_init_dft.context.reset(); llama_init_dft->free_context();
} }
chat_templates = common_chat_templates_init(model, params_base.chat_template); chat_templates = common_chat_templates_init(model, params_base.chat_template);
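
free_context() is what lets the server above keep the draft model weights loaded while discarding the throwaway context created during init (each slot later builds its own draft context from cparams_dft). A hedged sketch of that pattern; load_draft_model is an illustrative name:

#include "common.h"
#include "llama.h"

// sketch: load a draft model once, but drop its initial context immediately;
// per-slot contexts are created later from the same model
static llama_model * load_draft_model(common_init_result_ptr & init_dft, common_params & params_dft) {
    init_dft = common_init_from_params(params_dft);

    llama_model * model_dft = init_dft->model();
    if (model_dft == nullptr) {
        return nullptr;
    }

    init_dft->free_context();   // the shared context is not needed
    return model_dft;           // stays valid for llama_init_from_model() per slot
}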
@ -1051,18 +1048,15 @@ struct server_context_impl {
// initialize samplers // initialize samplers
{ {
if (slot.smpl != nullptr) { slot.smpl.reset(common_sampler_init(model, task.params.sampling));
common_sampler_free(slot.smpl);
}
slot.smpl = common_sampler_init(model, task.params.sampling);
if (slot.smpl == nullptr) { if (slot.smpl == nullptr) {
// for now, the only error that may happen here is invalid grammar // for now, the only error that may happen here is invalid grammar
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
return false; return false;
} }
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str()); SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
} }
// initialize draft batch // initialize draft batch
@ -1216,11 +1210,10 @@ struct server_context_impl {
} }
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const {
size_t n_probs = slot.task->params.sampling.n_probs; const size_t n_probs = slot.task->params.sampling.n_probs;
size_t n_vocab = llama_vocab_n_tokens(vocab);
if (post_sampling) { if (post_sampling) {
const auto * cur_p = common_sampler_get_candidates(slot.smpl, true); const auto * cur_p = common_sampler_get_candidates(slot.smpl.get(), true);
const size_t max_probs = cur_p->size; const size_t max_probs = cur_p->size;
// set probability for sampled token // set probability for sampled token
@ -1245,7 +1238,7 @@ struct server_context_impl {
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx); std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
// set probability for sampled token // set probability for sampled token
for (size_t i = 0; i < n_vocab; i++) { for (size_t i = 0; i < cur.size(); i++) {
// set probability for sampled token // set probability for sampled token
if (cur[i].id == result.tok) { if (cur[i].id == result.tok) {
result.prob = cur[i].p; result.prob = cur[i].p;
@ -1255,7 +1248,7 @@ struct server_context_impl {
// set probability for top n_probs tokens // set probability for top n_probs tokens
result.probs.reserve(n_probs); result.probs.reserve(n_probs);
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) { for (size_t i = 0; i < std::min(cur.size(), n_probs); i++) {
result.probs.push_back({ result.probs.push_back({
cur[i].id, cur[i].id,
common_token_to_piece(ctx, cur[i].id, special), common_token_to_piece(ctx, cur[i].id, special),
@ -2301,13 +2294,13 @@ struct server_context_impl {
GGML_ASSERT(batch.n_tokens > 0); GGML_ASSERT(batch.n_tokens > 0);
common_sampler_reset(slot.smpl); common_sampler_reset(slot.smpl.get());
// Process all prompt tokens through sampler system // Process all prompt tokens through sampler system
for (int i = 0; i < slot.task->n_tokens(); ++i) { for (int i = 0; i < slot.task->n_tokens(); ++i) {
llama_token id = input_tokens[i]; llama_token id = input_tokens[i];
if (id != LLAMA_TOKEN_NULL) { if (id != LLAMA_TOKEN_NULL) {
common_sampler_accept(slot.smpl, id, false); common_sampler_accept(slot.smpl.get(), id, false);
} }
} }
@ -2525,11 +2518,11 @@ struct server_context_impl {
const int tok_idx = slot.i_batch - i; const int tok_idx = slot.i_batch - i;
llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
slot.i_batch = -1; slot.i_batch = -1;
common_sampler_accept(slot.smpl, id, true); common_sampler_accept(slot.smpl.get(), id, true);
slot.n_decoded += 1; slot.n_decoded += 1;
@ -2570,7 +2563,7 @@ struct server_context_impl {
size_t n_draft = slot.drafted.size(); size_t n_draft = slot.drafted.size();
// the accepted tokens from the speculation // the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, slot.i_batch_dft, slot.drafted); const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, slot.i_batch_dft, slot.drafted);
slot.i_batch_dft.clear(); slot.i_batch_dft.clear();
slot.drafted.clear(); slot.drafted.clear();

View File

@ -684,7 +684,7 @@ def test_anthropic_streaming_content_block_indices():
# Request that might produce both text and tool use # Request that might produce both text and tool use
res = server.make_stream_request("POST", "/v1/messages", data={ res = server.make_stream_request("POST", "/v1/messages", data={
"model": "test", "model": "test",
"max_tokens": 200, "max_tokens": 400,
"stream": True, "stream": True,
"tools": [{ "tools": [{
"name": "test_tool", "name": "test_tool",

View File

@ -568,10 +568,10 @@ int main(int argc, char ** argv) {
llama_context * ctx_ttc = NULL; llama_context * ctx_ttc = NULL;
llama_context * ctx_cts = NULL; llama_context * ctx_cts = NULL;
common_init_result llama_init_ttc = common_init_from_params(params); auto llama_init_ttc = common_init_from_params(params);
model_ttc = llama_init_ttc.model.get(); model_ttc = llama_init_ttc->model();
ctx_ttc = llama_init_ttc.context.get(); ctx_ttc = llama_init_ttc->context();
if (model_ttc == nullptr || ctx_ttc == nullptr) { if (model_ttc == nullptr || ctx_ttc == nullptr) {
return ENOENT; return ENOENT;
@ -583,10 +583,10 @@ int main(int argc, char ** argv) {
params.embedding = true; params.embedding = true;
params.n_ubatch = params.n_batch; params.n_ubatch = params.n_batch;
common_init_result llama_init_cts = common_init_from_params(params); auto llama_init_cts = common_init_from_params(params);
model_cts = llama_init_cts.model.get(); model_cts = llama_init_cts->model();
ctx_cts = llama_init_cts.context.get(); ctx_cts = llama_init_cts->context();
if (model_cts == nullptr || ctx_cts == nullptr) { if (model_cts == nullptr || ctx_cts == nullptr) {
return ENOENT; return ENOENT;