refactor : simplify and improve memory management

This commit is contained in:
Georgi Gerganov 2025-11-28 11:47:59 +02:00
parent 459b7ae7b9
commit 117e2079a9
29 changed files with 424 additions and 449 deletions

View File

@ -950,31 +950,40 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
// Model utils
//
static inline void common_init_sampler_from_model(
// TODO: move to common/sampling
static void common_init_sampler_from_model(
const llama_model * model,
common_params_sampling & sparams) {
const uint64_t config = sparams.user_sampling_config;
auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
if (config & user_config) return;
if (config & user_config) {
return;
}
char buf[64] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
int32_t v = strtol(buf, &end, 10);
if (end && end != buf) dst = v;
if (end && end != buf) {
dst = v;
}
}
};
auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
if (config & user_config) return;
if (config & user_config) {
return;
}
char buf[128] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
float v = strtof(buf, &end);
if (end && end != buf) dst = v;
if (end && end != buf) {
dst = v;
}
}
};
@ -1002,45 +1011,130 @@ static inline void common_init_sampler_from_model(
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
}
struct common_init_result common_init_from_params(common_params & params) {
common_init_result iparams;
auto mparams = common_model_params_to_llama(params);
struct common_init_result::impl {
impl() = default;
~impl() = default;
llama_model_ptr model;
llama_context_ptr context;
std::vector<llama_adapter_lora_ptr> lora;
std::vector<common_sampler_ptr> samplers;
std::vector<llama_sampler_seq_config> samplers_seq_config;
};
common_init_result::common_init_result(common_params & params) :
pimpl(new impl{}) {
const auto mparams = common_model_params_to_llama(params);
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
if (model == NULL) {
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
return iparams;
return;
}
common_init_sampler_from_model(model, params.sampling);
pimpl->model.reset(model);
const llama_vocab * vocab = llama_model_get_vocab(model);
// updates params.sampling
// TODO: fix naming
common_init_sampler_from_model(model, params.sampling);
auto cparams = common_context_params_to_llama(params);
if (params.sampling.backend_sampling) {
llama_sampler * backend_chain = common_sampler_backend_init(model, params.sampling);
if (backend_chain != nullptr) {
iparams.samplers_seq_config.resize(cparams.n_seq_max);
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
iparams.samplers_seq_config[i] = { i, llama_sampler_clone(backend_chain) };
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
params.sampling.ignore_eos = false;
}
cparams.samplers = iparams.samplers_seq_config.data();
cparams.n_samplers = cparams.n_seq_max;
llama_sampler_free(backend_chain);
// initialize once
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
}
}
if (params.sampling.ignore_eos) {
// add EOG biases to the active set of logit biases
params.sampling.logit_bias.insert(
params.sampling.logit_bias.end(),
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
}
//if (params.sampling.penalty_last_n == -1) {
// LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
// params.sampling.penalty_last_n = llama_n_ctx(lctx);
//}
//if (params.sampling.dry_penalty_last_n == -1) {
// LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
// params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
//}
// init the backend samplers as part of the context creation
pimpl->samplers.resize(cparams.n_seq_max);
pimpl->samplers_seq_config.resize(cparams.n_seq_max);
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
llama_sampler * backend_chain = common_sampler_chain_backend(pimpl->samplers[i].get());
pimpl->samplers_seq_config[i] = { i, backend_chain };
}
cparams.samplers = pimpl->samplers_seq_config.data();
cparams.n_samplers = pimpl->samplers_seq_config.size();
llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
llama_model_free(model);
return iparams;
return;
}
pimpl->context.reset(lctx);
}
llama_model * common_init_result::model() {
return pimpl->model.get();
}
llama_context * common_init_result::context() {
return pimpl->context.get();
}
common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
return pimpl->samplers[seq_id].get();
}
std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
return pimpl->lora;
}
void common_init_result::free_context() {
pimpl->context.reset();
}
common_init_result_ptr common_init_from_params(common_params & params) {
common_init_result_ptr res(new common_init_result(params));
llama_model * model = res->model();
if (model == NULL) {
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
return res;
}
llama_context * lctx = res->context();
if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
return res;
}
const llama_vocab * vocab = llama_model_get_vocab(model);
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
params.ctx_shift = false;
@ -1052,10 +1146,7 @@ struct common_init_result common_init_from_params(common_params & params) {
const auto cvec = common_control_vector_load(params.control_vectors);
if (cvec.n_embd == -1) {
llama_free(lctx);
llama_model_free(model);
return iparams;
return res;
}
int err = llama_apply_adapter_cvec(
@ -1066,10 +1157,7 @@ struct common_init_result common_init_from_params(common_params & params) {
params.control_vector_layer_start,
params.control_vector_layer_end);
if (err) {
llama_free(lctx);
llama_model_free(model);
return iparams;
return res;
}
}
@ -1093,10 +1181,7 @@ struct common_init_result common_init_from_params(common_params & params) {
}
if (!ok) {
llama_free(lctx);
llama_model_free(model);
return iparams;
return res;
}
}
@ -1106,9 +1191,7 @@ struct common_init_result common_init_from_params(common_params & params) {
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
if (lora == nullptr) {
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
llama_free(lctx);
llama_model_free(model);
return iparams;
return res;
}
char buf[1024];
@ -1117,43 +1200,13 @@ struct common_init_result common_init_from_params(common_params & params) {
la.task_name = buf;
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
la.prompt_prefix = buf;
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
}
if (!params.lora_init_without_apply) {
common_set_adapter_lora(lctx, params.lora_adapters);
}
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
params.sampling.ignore_eos = false;
}
// initialize once
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
if (llama_vocab_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
}
}
if (params.sampling.ignore_eos) {
// add EOG biases to the active set of logit biases
params.sampling.logit_bias.insert(
params.sampling.logit_bias.end(),
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
}
if (params.sampling.penalty_last_n == -1) {
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
params.sampling.penalty_last_n = llama_n_ctx(lctx);
}
if (params.sampling.dry_penalty_last_n == -1) {
LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
}
if (params.warmup) {
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
@ -1192,12 +1245,11 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_set_warmup(lctx, false);
}
iparams.model.reset(model);
iparams.context.reset(lctx);
return iparams;
return res;
}
common_init_result::~common_init_result() = default;
std::string get_model_endpoint() {
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
@ -1206,7 +1258,9 @@ std::string get_model_endpoint() {
std::string model_endpoint = "https://huggingface.co/";
if (endpoint_env) {
model_endpoint = endpoint_env;
if (model_endpoint.back() != '/') model_endpoint += '/';
if (model_endpoint.back() != '/') {
model_endpoint += '/';
}
}
return model_endpoint;
}

View File

@ -192,7 +192,6 @@ struct common_params_sampling {
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
std::vector<enum common_sampler_type> samplers = {
COMMON_SAMPLER_TYPE_PENALTIES,
COMMON_SAMPLER_TYPE_DRY,
@ -215,6 +214,16 @@ struct common_params_sampling {
bool backend_sampling = false; // enable backend sampling
bool has_logit_bias() const {
return !logit_bias.empty();
}
bool is_disabled(enum common_sampler_type type) const;
// remove disabled samplers
// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
void filter_disabled();
// print the parameters into a string
std::string print() const;
};
@ -650,18 +659,29 @@ std::vector<common_file_info> fs_list_files(const std::string & path);
// Model utils
//
struct common_sampler;
// note: defines object's lifetime
struct common_init_result {
llama_model_ptr model;
llama_context_ptr context;
common_init_result(common_params & params);
~common_init_result();
std::vector<llama_adapter_lora_ptr> lora;
llama_model * model();
llama_context * context();
common_sampler * sampler(llama_seq_id seq_id);
std::vector<llama_sampler_ptr> samplers;
std::vector<llama_sampler_seq_config> samplers_seq_config;
std::vector<llama_adapter_lora_ptr> & lora();
void free_context();
private:
struct impl;
std::unique_ptr<impl> pimpl;
};
struct common_init_result common_init_from_params(common_params & params);
using common_init_result_ptr = std::unique_ptr<common_init_result>;
common_init_result_ptr common_init_from_params(common_params & params);
struct llama_model_params common_model_params_to_llama ( common_params & params);
struct llama_context_params common_context_params_to_llama(const common_params & params);
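
For reference, a minimal usage sketch (not part of the diff) of the refactored API declared above. The calls match the declarations in this header; `params` is an already-populated common_params and the rest is illustrative:

    // sketch: the unique_ptr returned by common_init_from_params() owns the
    // model, context, per-sequence samplers and LoRA adapters via the pimpl
    auto llama_init = common_init_from_params(params);

    llama_model   * model = llama_init->model();
    llama_context * ctx   = llama_init->context();
    if (model == nullptr || ctx == nullptr) {
        return 1; // load or context creation failed; nothing to free manually
    }

    common_sampler * smpl = llama_init->sampler(0); // sampler for seq_id 0

    // ... generation loop using model/ctx/smpl ...

    // no llama_model_free()/llama_free()/common_sampler_free() calls needed:
    // everything is released when llama_init goes out of scope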

View File

@ -105,7 +105,8 @@ struct common_sampler {
common_params_sampling params;
struct llama_sampler * grmr;
struct llama_sampler * chain; // CPU sampling chain
struct llama_sampler * chain;
struct llama_sampler * chain_backend;
ring_buffer<llama_token> prev;
@ -118,6 +119,7 @@ struct common_sampler {
llama_sampler_reset(grmr);
llama_sampler_reset(chain);
llama_sampler_reset(chain_backend);
}
void set_logits(struct llama_context * ctx, int idx) {
@ -161,7 +163,8 @@ struct common_sampler {
mutable int64_t t_total_us = 0;
};
static bool sampler_backend_supported(enum common_sampler_type type) {
// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
static bool common_sampler_type_has_backend_support(enum common_sampler_type type) {
switch (type) {
case COMMON_SAMPLER_TYPE_TOP_K:
case COMMON_SAMPLER_TYPE_TEMPERATURE:
@ -172,98 +175,69 @@ static bool sampler_backend_supported(enum common_sampler_type type) {
}
}
static bool is_sampler_enabled(enum common_sampler_type type, const struct common_params_sampling & params) {
bool common_params_sampling::is_disabled(enum common_sampler_type type) const {
switch (type) {
case COMMON_SAMPLER_TYPE_PENALTIES:
if (params.penalty_last_n == 64 &&
fabs(params.penalty_repeat) <= 1.0f &&
fabs(params.penalty_freq) <= 0.0f &&
fabs(params.penalty_present) <= 0.0f) {
return false;
if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
return true;
}
break;
case COMMON_SAMPLER_TYPE_DRY:
if (params.dry_multiplier == 0.0f && params.dry_base == 1.75f) {
return false;
if (dry_multiplier == 0.0f || dry_base < 1.0f || dry_penalty_last_n == 0) {
return true;
}
break;
case COMMON_SAMPLER_TYPE_TYPICAL_P:
if (params.typ_p == 1.0) {
return false;
if (typ_p >= 1.0) {
return true;
}
break;
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
if (params.top_n_sigma == -1.0) {
return false;
if (top_n_sigma <= 0.0) {
return true;
}
break;
case COMMON_SAMPLER_TYPE_TOP_K:
if (params.top_k <= 0) {
return false;
if (top_k <= 0) {
return true;
}
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
if (params.temp < 0.0f) {
return false;
if (dynatemp_range <= 0.0f) {
return true;
}
break;
case COMMON_SAMPLER_TYPE_MIN_P:
if (params.min_p <= 0.0f) {
return false;
if (min_p <= 0.0f) {
return true;
}
break;
case COMMON_SAMPLER_TYPE_TOP_P:
if (params.top_p >= 1.0f) {
return false;
if (top_p >= 1.0f) {
return true;
}
break;
case COMMON_SAMPLER_TYPE_XTC:
if (params.xtc_probability == 0.0f && params.xtc_threshold == 0.10f) {
return false;
if (xtc_probability <= 0.0f || xtc_threshold == 0.50f) {
return true;
}
break;
default:
break;
}
return true;
return false;
}
static bool has_logit_bias(const struct common_params_sampling & params) {
return !params.logit_bias.empty();
}
struct active_samplers {
std::vector<common_sampler_type> backend_samplers;
std::vector<common_sampler_type> cpu_samplers;
};
static struct active_samplers get_active_samplers(const struct common_params_sampling & params) {
struct active_samplers result;
if (params.mirostat != 0) {
// Mirostat is CPU-only and overrides other samplers
for (const auto & sampler_type : params.samplers) {
if (is_sampler_enabled(sampler_type, params)) {
result.cpu_samplers.push_back(sampler_type);
}
}
return result;
}
bool backend_supported = params.backend_sampling;
for (const auto & sampler_type : params.samplers) {
if (!is_sampler_enabled(sampler_type, params)) {
continue;
}
if (backend_supported && sampler_backend_supported(sampler_type)) {
result.backend_samplers.push_back(sampler_type);
void common_params_sampling::filter_disabled() {
for (auto it = samplers.begin(); it != samplers.end();) {
if (is_disabled(*it)) {
LOG_WRN("%s: removing disabled sampler %s\n", __func__, common_sampler_type_to_str(*it).c_str());
it = samplers.erase(it);
} else {
result.cpu_samplers.push_back(sampler_type);
++it;
}
}
return result;
}
std::string common_params_sampling::print() const {
@ -282,15 +256,7 @@ std::string common_params_sampling::print() const {
return std::string(result);
}
struct backend_chain_data {
struct llama_sampler * chain;
size_t count;
};
static struct backend_chain_data backend_samplers_init(const struct llama_model * model, const struct common_params_sampling & params,
struct active_samplers get_active_samplers);
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
@ -357,29 +323,74 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
}
}
// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
if (params.backend_sampling) {
params.filter_disabled();
}
auto * result = new common_sampler {
/* .params = */ params,
/* .grmr = */ grmr,
/* .chain = */ llama_sampler_chain_init(lparams),
/* .chain_backend = */ llama_sampler_chain_init(lparams),
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
/* .cur_p = */ {},
};
struct active_samplers active_samplers = get_active_samplers(params);
size_t idx_smpl = 0;
// Build CPU chain
if (!params.backend_sampling || !has_logit_bias(params)) {
bool is_backend = true;
is_backend = is_backend && params.backend_sampling;
is_backend = is_backend && (params.samplers.size() == 0 || common_sampler_type_has_backend_support(params.samplers[idx_smpl]));
if (params.has_logit_bias()) {
if (is_backend) {
llama_sampler_chain_add(result->chain_backend,
llama_sampler_backend_init_logit_bias(
llama_vocab_n_tokens(vocab),
params.logit_bias.size(),
params.logit_bias.data()));
} else {
llama_sampler_chain_add(result->chain,
llama_sampler_init_logit_bias(
llama_vocab_n_tokens(vocab),
params.logit_bias.size(),
params.logit_bias.data()));
}
}
if (params.mirostat == 0) {
// backend samplers are added first
while (is_backend && idx_smpl < params.samplers.size()) {
const auto & cnstr = params.samplers[idx_smpl++];
if (!common_sampler_type_has_backend_support(cnstr)) {
is_backend = false;
--idx_smpl;
break;
}
switch (cnstr) {
case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_top_k(params.top_k));
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_temp(params.temp));
break;
case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_min_p(params.min_p));
break;
default:
GGML_ASSERT(false && "unsupported backend sampler");
}
}
// Add remaining CPU samplers
for (const auto & cnstr : active_samplers.cpu_samplers) {
while (idx_smpl < params.samplers.size()) {
const auto & cnstr = params.samplers[idx_smpl++];
switch (cnstr) {
case COMMON_SAMPLER_TYPE_DRY:
{
@ -424,7 +435,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
}
}
if (!active_samplers.cpu_samplers.empty()) {
if (is_backend) {
llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_dist(params.seed));
} else {
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
}
} else if (params.mirostat == 1) {
@ -440,59 +453,11 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
return result;
}
static struct backend_chain_data backend_samplers_init(const struct llama_model * model, const struct common_params_sampling & params,
struct active_samplers active_samplers) {
if (active_samplers.backend_samplers.empty()) {
return { nullptr, 0 };
}
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
chain_params.no_perf = params.no_perf;
struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
// Add logit_bias to backend chain if present
if (has_logit_bias(params)) {
llama_sampler_chain_add(chain, llama_sampler_backend_init_logit_bias(
llama_vocab_n_tokens(vocab),
params.logit_bias.size(),
params.logit_bias.data()));
}
for (const auto & sampler_type : active_samplers.backend_samplers) {
switch (sampler_type) {
case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k));
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp));
break;
case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(chain, llama_sampler_backend_init_min_p(params.min_p));
break;
default:
GGML_ASSERT(false && "unsupported backend sampler");
}
}
if (active_samplers.cpu_samplers.empty()) {
llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(params.seed));
}
return { chain, active_samplers.backend_samplers.size() + has_logit_bias(params) };
}
struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) {
struct active_samplers active_samplers = get_active_samplers(params);
return backend_samplers_init(model, params, active_samplers).chain;
}
void common_sampler_free(struct common_sampler * gsmpl) {
if (gsmpl) {
llama_sampler_free(gsmpl->grmr);
llama_sampler_free(gsmpl->chain);
llama_sampler_free(gsmpl->chain_backend);
delete gsmpl;
}
@ -519,6 +484,7 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
/* .params = */ gsmpl->params,
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .chain_backend = */ llama_sampler_clone(gsmpl->chain_backend),
/* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur,
/* .cur_p = */ gsmpl->cur_p,
@ -570,6 +536,10 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
}
}
struct llama_sampler * common_sampler_chain_backend(const struct common_sampler * gsmpl) {
return gsmpl->chain_backend;
}
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
// Check if a backend sampler has already sampled a token in which case we
// return that token id directly.
@ -707,7 +677,12 @@ llama_token common_sampler_last(const struct common_sampler * gsmpl) {
}
std::string common_sampler_print(const struct common_sampler * gsmpl) {
std::string result = "logits ";
std::string result = llama_sampler_chain_n(gsmpl->chain_backend) > 0 ? "*logits " : "logits ";
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain_backend); i++) {
const auto * smpl = llama_sampler_chain_get(gsmpl->chain_backend, i);
result += std::string("-> *") + llama_sampler_name(smpl) + " ";
}
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);

View File

@ -36,14 +36,8 @@ struct common_sampler;
// llama_sampler API overloads
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
// Create a backend sampler chain from common sampling parameters
// Returns a llama_sampler chain configured with backend samplers based on the parameters
// This chain can be used per-sequence for backend-based sampling
// Note: Only samplers that have backend equivalents will be added to the chain
// The returned sampler should be freed with llama_sampler_free()
struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params);
// TODO: params should become const again [LLAMA_SAMPLER_BACKEND]
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params);
void common_sampler_free(struct common_sampler * gsmpl);
@ -55,6 +49,8 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
// arguments can be nullptr to skip printing
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
struct llama_sampler * common_sampler_chain_backend(const struct common_sampler * gsmpl);
// extended sampling implementation:
//
// - set logits
@ -114,3 +110,9 @@ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std:
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
const char * grammar_kind, const char * grammar_data);
struct common_sampler_deleter {
void operator()(common_sampler * s) { common_sampler_free(s); }
};
typedef std::unique_ptr<common_sampler, common_sampler_deleter> common_sampler_ptr;
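
As a hedged illustration of the deleter and typedef above (a sketch, not taken from the diff; `model`, `ctx`, `seq_id` and `params` are assumed to be in scope):

    // sketch: common_sampler_ptr frees the sampler via common_sampler_free()
    // when it goes out of scope or is reset
    common_sampler_ptr smpl(common_sampler_init(model, params.sampling));
    if (!smpl) {
        // e.g. invalid grammar
    }

    // the backend portion of the chain can be handed to the context;
    // the common_sampler_ptr retains ownership of it
    llama_sampler * backend_chain = common_sampler_chain_backend(smpl.get());
    llama_set_backend_sampler(ctx, seq_id, backend_chain);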

View File

@ -65,29 +65,34 @@ int main(int argc, char ** argv) {
ctx_params.n_ctx = n_kv_req;
ctx_params.n_batch = std::max(n_predict, n_parallel);
std::vector<llama_sampler_seq_config> sampler_configs(n_parallel);
if (params.sampling.backend_sampling) {
for (int32_t i = 0; i < n_parallel; ++i) {
llama_sampler * backend_sampler = common_sampler_backend_init(model, params.sampling);
if (backend_sampler) {
sampler_configs[i] = { i, backend_sampler };
}
}
ctx_params.samplers = sampler_configs.data();
ctx_params.n_samplers = n_parallel;
}
llama_context * ctx = llama_init_from_model(model, ctx_params);
auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false;
std::vector<llama_sampler_seq_config> sampler_configs;
for (int32_t i = 0; i < n_parallel; ++i) {
llama_sampler * smpl = llama_sampler_chain_init(sparams);
if (params.sampling.backend_sampling) {
llama_sampler_chain_add(smpl, llama_sampler_backend_init_top_k(params.sampling.top_k));
llama_sampler_chain_add(smpl, llama_sampler_backend_init_temp (params.sampling.temp));
llama_sampler_chain_add(smpl, llama_sampler_backend_init_dist (params.sampling.seed));
} else {
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
}
sampler_configs.push_back({ i, smpl });
}
if (params.sampling.backend_sampling) {
ctx_params.samplers = sampler_configs.data();
ctx_params.n_samplers = sampler_configs.size();
}
llama_context * ctx = llama_init_from_model(model, ctx_params);
if (ctx == NULL) {
LOG_ERR("%s: error: failed to create the llama_context\n" , __func__);
@ -186,7 +191,7 @@ int main(int argc, char ** argv) {
continue;
}
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
const llama_token new_token_id = llama_sampler_sample(sampler_configs[i].sampler, ctx, i_batch[i]);
// is it an end of generation? -> mark the stream as finished
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {
@ -242,14 +247,17 @@ int main(int argc, char ** argv) {
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
LOG("\n");
llama_perf_sampler_print(smpl);
llama_perf_sampler_print(sampler_configs[0].sampler);
llama_perf_context_print(ctx);
fprintf(stderr, "\n");
llama_batch_free(batch);
llama_sampler_free(smpl);
for (auto & sampler_config : sampler_configs) {
llama_sampler_free(sampler_config.sampler);
}
llama_free(ctx);
llama_model_free(model);

View File

@ -130,10 +130,10 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa);
// load the model
common_init_result llama_init = common_init_from_params(params);
auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
auto * model = llama_init->model();
auto * ctx = llama_init->context();
if (model == NULL) {
LOG_ERR("%s: unable to load model\n", __func__);

View File

@ -202,10 +202,10 @@ int main(int argc, char ** argv) {
params.warmup = false;
// init
common_init_result llama_init = common_init_from_params(params);
auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
auto * model = llama_init->model();
auto * ctx = llama_init->context();
if (model == nullptr || ctx == nullptr) {
LOG_ERR("%s : failed to init\n", __func__);

View File

@ -55,10 +55,10 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa);
// load the target model
common_init_result llama_init = common_init_from_params(params);
auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
auto * model = llama_init->model();
auto * ctx = llama_init->context();
auto * mem = llama_get_memory(ctx);

View File

@ -18,16 +18,16 @@ int main(int argc, char ** argv){
llama_numa_init(params.numa);
// load the model
common_init_result llama_init = common_init_from_params(params);
auto llama_init = common_init_from_params(params);
llama_model_ptr & model = llama_init.model;
llama_context_ptr & ctx = llama_init.context;
auto * model = llama_init->model();
auto * ctx = llama_init->context();
GGML_ASSERT(model != nullptr);
// tokenize the prompt
std::vector<llama_token> inp;
inp = common_tokenize(ctx.get(), params.prompt, true, true);
inp = common_tokenize(ctx, params.prompt, true, true);
fprintf(stderr, "%s: tokenization done\n", __func__);
common_ngram_cache ngram_cache;

View File

@ -28,13 +28,13 @@ int main(int argc, char ** argv){
llama_numa_init(params.numa);
// load the model
common_init_result llama_init = common_init_from_params(params);
auto llama_init = common_init_from_params(params);
llama_context_ptr & ctx = llama_init.context;
llama_context * ctx = llama_init->context();
// tokenize the prompt
std::vector<llama_token> inp;
inp = common_tokenize(ctx.get(), params.prompt, true, true);
inp = common_tokenize(ctx, params.prompt, true, true);
common_ngram_cache ngram_cache_context;
common_ngram_cache ngram_cache_dynamic;
@ -65,7 +65,7 @@ int main(int argc, char ** argv){
}
const int n_input = inp.size();
const int n_ctx = llama_n_ctx(ctx.get());
const int n_ctx = llama_n_ctx(ctx);
int n_drafted = 0;
int n_accept = 0;

View File

@ -29,10 +29,10 @@ int main(int argc, char ** argv){
llama_numa_init(params.numa);
// load the model
common_init_result llama_init = common_init_from_params(params);
auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
auto * model = llama_init->model();
auto * ctx = llama_init->context();
const llama_vocab * vocab = llama_model_get_vocab(model);

View File

@ -192,10 +192,10 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa);
// load the target model
common_init_result llama_init = common_init_from_params(params);
auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
auto * model = llama_init->model();
auto * ctx = llama_init->context();
auto * mem = llama_get_memory(ctx);

View File

@ -149,10 +149,10 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa);
// load the model
common_init_result llama_init = common_init_from_params(params);
auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
auto * model = llama_init->model();
auto * ctx = llama_init->context();
if (model == NULL) {
LOG_ERR("%s: unable to load model\n", __func__);

View File

@ -34,10 +34,10 @@ int main(int argc, char ** argv) {
std::string result2;
// init
common_init_result llama_init = common_init_from_params(params);
auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
auto * model = llama_init->model();
auto * ctx = llama_init->context();
if (model == nullptr || ctx == nullptr) {
fprintf(stderr, "%s : failed to init\n", __func__);

View File

@ -40,10 +40,10 @@ int main(int argc, char ** argv) {
llama_context * ctx_dft = NULL;
// load the target model
common_init_result llama_init_tgt = common_init_from_params(params);
auto llama_init_tgt = common_init_from_params(params);
model_tgt = llama_init_tgt.model.get();
ctx_tgt = llama_init_tgt.context.get();
model_tgt = llama_init_tgt->model();
ctx_tgt = llama_init_tgt->context();
const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
@ -61,10 +61,10 @@ int main(int argc, char ** argv) {
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
common_init_result llama_init_dft = common_init_from_params(params);
auto llama_init_dft = common_init_from_params(params);
//model_dft = llama_init_dft.model.get();
ctx_dft = llama_init_dft.context.get();
//model_dft = llama_init_dft->model();
ctx_dft = llama_init_dft->context();
if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str());

View File

@ -71,10 +71,10 @@ int main(int argc, char ** argv) {
llama_context * ctx_dft = NULL;
// load the target model
common_init_result llama_init_tgt = common_init_from_params(params);
auto llama_init_tgt = common_init_from_params(params);
model_tgt = llama_init_tgt.model.get();
ctx_tgt = llama_init_tgt.context.get();
model_tgt = llama_init_tgt->model();
ctx_tgt = llama_init_tgt->context();
// load the draft model
params.devices = params.speculative.devices;
@ -87,10 +87,10 @@ int main(int argc, char ** argv) {
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
common_init_result llama_init_dft = common_init_from_params(params);
auto llama_init_dft = common_init_from_params(params);
model_dft = llama_init_dft.model.get();
ctx_dft = llama_init_dft.context.get();
model_dft = llama_init_dft->model();
ctx_dft = llama_init_dft->context();
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);

View File

@ -39,9 +39,10 @@ int main(int argc, char ** argv) {
llama_backend_init();
llama_numa_init(params.numa);
// load the model and apply lora adapter, if any
common_init_result llama_init = common_init_from_params(params);
llama_model_ptr & model = llama_init.model;
llama_context_ptr & ctx = llama_init.context;
auto llama_init = common_init_from_params(params);
auto * model = llama_init->model();
auto * ctx = llama_init->context();
if (model == NULL) {
LOG_ERR("%s: unable to load model\n", __func__);
@ -54,8 +55,8 @@ int main(int argc, char ** argv) {
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
}
std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get()) / 2);
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);
ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx, tokens, llama_n_ctx(ctx) / 2);
struct lr_opt & lr = params.lr;
LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %d -period %.2g -val %.2g\n",
@ -70,7 +71,7 @@ int main(int argc, char ** argv) {
/*get_opt_pars_ud =*/&params.lr,
/*optimizer_type =*/params.optimizer,
};
llama_opt_init(ctx.get(), model.get(), lopt_params);
llama_opt_init(ctx, model, lopt_params);
const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);
@ -78,7 +79,7 @@ int main(int argc, char ** argv) {
ggml_opt_result_t result_eval = ggml_opt_result_init();
for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) {
llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
fprintf(stderr, "\n");
@ -88,7 +89,7 @@ int main(int argc, char ** argv) {
ggml_opt_result_free(result_train);
ggml_opt_result_free(result_eval);
llama_model_save_to_file(model.get(), params.out_file.c_str());
llama_model_save_to_file(model, params.out_file.c_str());
llama_backend_free();

View File

@ -376,7 +376,7 @@ extern "C" {
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
// backend sampler chain configuration
// backend sampler chain configuration (does not keep a reference, so make sure the caller keeps the samplers alive)
struct llama_sampler_seq_config * samplers;
size_t n_samplers;
};
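
A hedged sketch of the lifetime requirement noted in the comment above: llama_context_params only stores the pointer, so the caller must keep the configs (and the sampler chains they reference) alive for as long as the context uses them. `n_seq_max`, `cparams`, `model` and the sampler values are placeholders:

    // sketch: build one backend chain per sequence and keep ownership in the caller
    std::vector<llama_sampler_seq_config> configs(n_seq_max);
    for (int32_t i = 0; i < (int32_t) n_seq_max; ++i) {
        llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
        llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(40));
        llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(1234));
        configs[i] = { i, chain };
    }

    cparams.samplers   = configs.data();
    cparams.n_samplers = configs.size();

    llama_context * lctx = llama_init_from_model(model, cparams);

    // configs and the chains must outlive lctx; free the chains only after llama_free(lctx)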
@ -1209,10 +1209,6 @@ extern "C" {
void (*init_ggml)(struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft);
// TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
//void (*apply_ggml) (struct llama_sampler * smpl, ...);
};
struct llama_sampler {
@ -1231,14 +1227,13 @@ extern "C" {
LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl);
// important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl);
LLAMA_API void llama_sampler_init_ggml(struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft);
LLAMA_API void llama_sampler_init_ggml (struct llama_sampler * smpl, ggml_backend_buffer_type_t buft);
LLAMA_API void llama_sampler_set_input_ggml(struct llama_sampler * smpl);
LLAMA_API void llama_sampler_apply_ggml (struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data);
LLAMA_API void llama_sampler_accept_ggml (struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,

View File

@ -66,7 +66,15 @@ llama_context::llama_context(
for (size_t i = 0; i < params.n_samplers; ++i) {
const auto & config = params.samplers[i];
const int n_samplers = llama_sampler_chain_n(config.sampler);
if (n_samplers <= 0) {
continue;
}
sampling.samplers[config.seq_id] = config.sampler;
LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
}
}
@ -438,8 +446,8 @@ llama_context::llama_context(
// Initialize the full vocabulary token ids for backend samplers.
{
const llama_vocab * vocab = llama_model_get_vocab(&model);
const int n_vocab = llama_vocab_n_tokens(vocab);
const int n_vocab = model.vocab.n_tokens();
sampling.token_ids_full_vocab.resize(n_vocab);
for (int i = 0; i < n_vocab; ++i) {
sampling.token_ids_full_vocab[i] = i;
@ -449,10 +457,6 @@ llama_context::llama_context(
llama_context::~llama_context() {
ggml_opt_free(opt_ctx);
// TODO: perhaps use a smart pointer for samplers
for (auto const& [seq_id, sampler] : sampling.samplers) {
llama_sampler_free(sampler);
}
}
void llama_context::synchronize() {
@ -910,31 +914,10 @@ void llama_context::set_warmup(bool value) {
void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
auto it = sampling.samplers.find(seq_id);
if (it != sampling.samplers.end()) {
// If the sampler to be set is the same that is already set, do nothing.
if (it->second == sampler) {
return;
}
llama_sampler_free(it->second);
// If sampler is nullptr, we remove the sampler chain for this seq_id.
if (sampler == nullptr) {
sampling.samplers.erase(it);
return;
}
// Otherwise, we replace the existing sampler with the new one.
it->second = sampler;
return;
}
// If there is no sampler for this seq_id and the caller provides a non-null
// sampler, we set it.
if (sampler != nullptr) {
if (sampler != nullptr && llama_sampler_chain_n(sampler) > 0) {
sampling.samplers[seq_id] = sampler;
} else {
sampling.samplers.erase(seq_id);
}
}
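
A brief sketch of the simplified behavior above, assuming a `ctx` and a previously built `chain`: a chain is registered only if it is non-empty, and passing nullptr (or an empty chain) clears the entry for that sequence.

    // sketch: attach a backend chain to sequence 0, then detach it again
    llama_set_backend_sampler(ctx, 0, chain);   // kept only if llama_sampler_chain_n(chain) > 0
    llama_set_backend_sampler(ctx, 0, nullptr); // removes whatever was set for seq_id 0
    // the context does not take ownership: the caller keeps the chain alive and frees it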

View File

@ -419,10 +419,10 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa);
// load the model to get hparams
common_init_result llama_init = common_init_from_params(params);
auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
auto * model = llama_init->model();
auto * ctx = llama_init->context();
// int n_ctx = llama_n_ctx(ctx);
int n_layers = llama_model_n_layer(model);

View File

@ -1265,10 +1265,10 @@ int main(int argc, char ** argv) {
params.warmup = false;
// init
common_init_result llama_init = common_init_from_params(params);
auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
auto * model = llama_init->model();
auto * ctx = llama_init->context();
if (model == nullptr || ctx == nullptr) {
LOG_ERR("%s : failed to init\n", __func__);

View File

@ -138,9 +138,11 @@ int main(int argc, char ** argv) {
// load the model and apply lora adapter, if any
LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
common_init_result llama_init = common_init_from_params(params);
ctx = llama_init.context.get();
model = llama_init.model.get(); // Update pointer (now managed by llama_init)
auto llama_init = common_init_from_params(params);
ctx = llama_init->context();
model = llama_init->model();
smpl = llama_init->sampler(0);
if (ctx == NULL) {
LOG_ERR("%s: error: unable to create context\n", __func__);
@ -470,12 +472,6 @@ int main(int argc, char ** argv) {
}
}
smpl = common_sampler_init(model, sparams);
if (!smpl) {
LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
return 1;
}
LOG_INF("sampler seed: %u\n", common_sampler_get_seed(smpl));
LOG_INF("sampler params: \n%s\n", sparams.print().c_str());
LOG_INF("sampler chain: %s\n", common_sampler_print(smpl).c_str());
@ -989,8 +985,6 @@ int main(int argc, char ** argv) {
LOG("\n\n");
common_perf_print(ctx, smpl);
common_sampler_free(smpl);
llama_backend_free();
ggml_threadpool_free_fn(threadpool);

View File

@ -65,7 +65,7 @@ static void sigint_handler(int signo) {
struct mtmd_cli_context {
mtmd::context_ptr ctx_vision;
common_init_result llama_init;
common_init_result_ptr llama_init;
llama_model * model;
llama_context * lctx;
@ -89,8 +89,8 @@ struct mtmd_cli_context {
llama_pos n_past = 0;
mtmd_cli_context(common_params & params) : llama_init(common_init_from_params(params)) {
model = llama_init.model.get();
lctx = llama_init.context.get();
model = llama_init->model();
lctx = llama_init->context();
vocab = llama_model_get_vocab(model);
smpl = common_sampler_init(model, params.sampling);
n_threads = params.cpuparams.n_threads;

View File

@ -2024,10 +2024,10 @@ int main(int argc, char ** argv) {
llama_numa_init(params.numa);
// load the model and apply lora adapter, if any
common_init_result llama_init = common_init_from_params(params);
auto llama_init = common_init_from_params(params);
llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get();
auto * model = llama_init->model();
auto * ctx = llama_init->context();
if (model == NULL) {
LOG_ERR("%s: unable to load model\n", __func__);

View File

@ -151,8 +151,7 @@ struct server_slot {
// sampling
json json_schema;
struct common_sampler * smpl = nullptr;
llama_sampler * backend_sampler = nullptr;
common_sampler_ptr smpl;
llama_token sampled;
@ -198,13 +197,6 @@ struct server_slot {
n_draft_total = 0;
n_draft_accepted = 0;
if (backend_sampler != nullptr) {
if (ctx != nullptr) {
llama_set_backend_sampler(ctx, id, nullptr);
}
backend_sampler = nullptr;
}
task.reset();
task_prev.reset();
@ -481,8 +473,8 @@ struct server_context {
common_params params_base;
// note: keep these alive - they determine the lifetime of the model, context, etc.
common_init_result llama_init;
common_init_result llama_init_dft;
common_init_result_ptr llama_init;
common_init_result_ptr llama_init_dft;
llama_model * model = nullptr;
llama_context * ctx = nullptr;
@ -526,16 +518,6 @@ struct server_context {
// Clear any sampling context
for (server_slot & slot : slots) {
common_sampler_free(slot.smpl);
slot.smpl = nullptr;
if (slot.backend_sampler != nullptr) {
if (ctx != nullptr) {
llama_set_backend_sampler(ctx, slot.id, nullptr);
}
slot.backend_sampler = nullptr;
}
llama_free(slot.ctx_dft);
slot.ctx_dft = nullptr;
@ -556,8 +538,8 @@ struct server_context {
llama_init = common_init_from_params(params_base);
model = llama_init.model.get();
ctx = llama_init.context.get();
model = llama_init->model();
ctx = llama_init->context();
if (model == nullptr) {
SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
@ -589,25 +571,25 @@ struct server_context {
llama_init_dft = common_init_from_params(params_dft);
model_dft = llama_init_dft.model.get();
model_dft = llama_init_dft->model();
if (model_dft == nullptr) {
SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str());
return false;
}
vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get());
vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft->context());
if (!vocab_dft_compatible) {
SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
}
const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());
const int n_ctx_dft = llama_n_ctx(llama_init_dft->context());
cparams_dft = common_context_params_to_llama(params_dft);
cparams_dft.n_batch = n_ctx_dft;
// the context is not needed - we will create one for each slot
llama_init_dft.context.reset();
llama_init_dft->free_context();
}
chat_templates = common_chat_templates_init(model, params_base.chat_template);
@ -1001,23 +983,17 @@ struct server_context {
// initialize samplers
{
if (slot.smpl != nullptr) {
common_sampler_free(slot.smpl);
}
slot.smpl = common_sampler_init(model, task.params.sampling);
slot.smpl.reset(common_sampler_init(model, task.params.sampling));
if (slot.smpl == nullptr) {
// for now, the only error that may happen here is invalid grammar
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
return false;
}
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str());
}
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
if (!configure_slot_backend_sampler(slot, task.params.sampling)) {
send_error(task, "Failed to configure backend samplers", ERROR_TYPE_SERVER);
return false;
llama_sampler * backend_chain = common_sampler_chain_backend(slot.smpl.get());
llama_set_backend_sampler(ctx, slot.id, backend_chain);
}
// initialize draft batch
@ -1037,39 +1013,6 @@ struct server_context {
return true;
}
bool configure_slot_backend_sampler(server_slot & slot, const common_params_sampling & sampling) {
if (!sampling.backend_sampling) {
if (slot.backend_sampler != nullptr) {
llama_set_backend_sampler(ctx, slot.id, nullptr);
slot.backend_sampler = nullptr;
}
return true;
}
llama_sampler * backend_chain = common_sampler_backend_init(model, sampling);
// The sampler types configured with --samplers might not be supported
// by backend samplers in which case we disable backend sampling and
// fallback to CPU only sampling.
if (backend_chain == nullptr) {
if (slot.backend_sampler != nullptr) {
llama_set_backend_sampler(ctx, slot.id, nullptr);
slot.backend_sampler = nullptr;
}
SLT_INF(slot, "%s", "no backend samplers configured (sampler chain doesn't start with backend-supported samplers)\n");
return true;
}
if (slot.backend_sampler != nullptr) {
llama_set_backend_sampler(ctx, slot.id, nullptr);
slot.backend_sampler = nullptr;
}
slot.backend_sampler = backend_chain;
llama_set_backend_sampler(ctx, slot.id, backend_chain);
SLT_INF(slot, "%s", "configured backend samplers\n");
return true;
}
bool process_token(completion_token_output & result, server_slot & slot) {
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = result.text_to_send;
@ -1206,7 +1149,7 @@ struct server_context {
size_t n_vocab = llama_vocab_n_tokens(vocab);
if (post_sampling) {
const auto * cur_p = common_sampler_get_candidates(slot.smpl, true);
const auto * cur_p = common_sampler_get_candidates(slot.smpl.get(), true);
const size_t max_probs = cur_p->size;
// set probability for sampled token
@ -2185,13 +2128,13 @@ struct server_context {
GGML_ASSERT(batch.n_tokens > 0);
common_sampler_reset(slot.smpl);
common_sampler_reset(slot.smpl.get());
// Process all prompt tokens through sampler system
for (int i = 0; i < slot.task->n_tokens(); ++i) {
llama_token id = input_tokens[i];
if (id != LLAMA_TOKEN_NULL) {
common_sampler_accept(slot.smpl, id, false);
common_sampler_accept(slot.smpl.get(), id, false);
}
}
@ -2381,11 +2324,11 @@ struct server_context {
const int tok_idx = slot.i_batch - i;
llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
slot.i_batch = -1;
common_sampler_accept(slot.smpl, id, true);
common_sampler_accept(slot.smpl.get(), id, true);
slot.n_decoded += 1;
@ -2488,7 +2431,7 @@ struct server_context {
llama_decode(ctx, slot.batch_spec);
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, draft);
slot.n_decoded += ids.size();

View File

@ -568,10 +568,10 @@ int main(int argc, char ** argv) {
llama_context * ctx_ttc = NULL;
llama_context * ctx_cts = NULL;
common_init_result llama_init_ttc = common_init_from_params(params);
auto llama_init_ttc = common_init_from_params(params);
model_ttc = llama_init_ttc.model.get();
ctx_ttc = llama_init_ttc.context.get();
model_ttc = llama_init_ttc->model();
ctx_ttc = llama_init_ttc->context();
if (model_ttc == nullptr || ctx_ttc == nullptr) {
return ENOENT;
@ -583,10 +583,10 @@ int main(int argc, char ** argv) {
params.embedding = true;
params.n_ubatch = params.n_batch;
common_init_result llama_init_cts = common_init_from_params(params);
auto llama_init_cts = common_init_from_params(params);
model_cts = llama_init_cts.model.get();
ctx_cts = llama_init_cts.context.get();
model_cts = llama_init_cts->model();
ctx_cts = llama_init_cts->context();
if (model_cts == nullptr || ctx_cts == nullptr) {
return ENOENT;