refactor : simplify and improve memory management
This commit is contained in:
parent 459b7ae7b9
commit 117e2079a9
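In short: common_init_from_params() now returns a common_init_result_ptr whose pimpl owns the model, context, LoRA adapters and per-sequence samplers, instead of exposing llama_model_ptr/llama_context_ptr members directly. Below is a minimal sketch of the new call-site pattern, distilled from the hunks that follow; the include path and the surrounding parameter setup are assumptions, not part of this commit.

#include "common.h" // assumed include path

static int demo_init(common_params & params) {
    // the returned unique_ptr frees the model, context and samplers on scope exit
    common_init_result_ptr llama_init = common_init_from_params(params);

    llama_model   * model = llama_init->model();
    llama_context * ctx   = llama_init->context();
    if (model == nullptr || ctx == nullptr) {
        return 1; // load or context creation failed
    }

    // per-sequence sampler now owned by the init result (see the main.cpp hunk below)
    common_sampler * smpl = llama_init->sampler(0);
    (void) smpl;

    return 0;
}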
@ -950,31 +950,40 @@ std::vector<common_file_info> fs_list_files(const std::string & path) {
// Model utils
//

static inline void common_init_sampler_from_model(
// TODO: move to common/sampling
static void common_init_sampler_from_model(
const llama_model * model,
common_params_sampling & sparams) {

const uint64_t config = sparams.user_sampling_config;

auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
if (config & user_config) return;
if (config & user_config) {
return;
}

char buf[64] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
int32_t v = strtol(buf, &end, 10);
if (end && end != buf) dst = v;
if (end && end != buf) {
dst = v;
}
}
};

auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
if (config & user_config) return;
if (config & user_config) {
return;
}

char buf[128] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
float v = strtof(buf, &end);
if (end && end != buf) dst = v;
if (end && end != buf) {
dst = v;
}
}
};

@ -1002,45 +1011,130 @@ static inline void common_init_sampler_from_model(
|
|||
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
|
||||
}
|
||||
|
||||
struct common_init_result common_init_from_params(common_params & params) {
|
||||
common_init_result iparams;
|
||||
auto mparams = common_model_params_to_llama(params);
|
||||
struct common_init_result::impl {
|
||||
impl() = default;
|
||||
~impl() = default;
|
||||
|
||||
llama_model_ptr model;
|
||||
llama_context_ptr context;
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> lora;
|
||||
|
||||
std::vector<common_sampler_ptr> samplers;
|
||||
std::vector<llama_sampler_seq_config> samplers_seq_config;
|
||||
};
|
||||
|
||||
common_init_result::common_init_result(common_params & params) :
|
||||
pimpl(new impl{}) {
|
||||
const auto mparams = common_model_params_to_llama(params);
|
||||
|
||||
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
|
||||
__func__, params.model.path.c_str());
|
||||
return iparams;
|
||||
return;
|
||||
}
|
||||
|
||||
common_init_sampler_from_model(model, params.sampling);
|
||||
pimpl->model.reset(model);
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
// updates params.sampling
|
||||
// TODO: fix naming
|
||||
common_init_sampler_from_model(model, params.sampling);
|
||||
|
||||
auto cparams = common_context_params_to_llama(params);
|
||||
|
||||
if (params.sampling.backend_sampling) {
|
||||
llama_sampler * backend_chain = common_sampler_backend_init(model, params.sampling);
|
||||
if (backend_chain != nullptr) {
|
||||
iparams.samplers_seq_config.resize(cparams.n_seq_max);
|
||||
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
|
||||
iparams.samplers_seq_config[i] = { i, llama_sampler_clone(backend_chain) };
|
||||
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
||||
params.sampling.ignore_eos = false;
|
||||
}
|
||||
cparams.samplers = iparams.samplers_seq_config.data();
|
||||
cparams.n_samplers = cparams.n_seq_max;
|
||||
|
||||
llama_sampler_free(backend_chain);
|
||||
// initialize once
|
||||
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||
if (llama_vocab_is_eog(vocab, i)) {
|
||||
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), -INFINITY);
|
||||
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
|
||||
}
|
||||
}
|
||||
|
||||
if (params.sampling.ignore_eos) {
|
||||
// add EOG biases to the active set of logit biases
|
||||
params.sampling.logit_bias.insert(
|
||||
params.sampling.logit_bias.end(),
|
||||
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
|
||||
}
|
||||
|
||||
//if (params.sampling.penalty_last_n == -1) {
|
||||
// LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
// params.sampling.penalty_last_n = llama_n_ctx(lctx);
|
||||
//}
|
||||
|
||||
//if (params.sampling.dry_penalty_last_n == -1) {
|
||||
// LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
// params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
|
||||
//}
|
||||
|
||||
// init the backend samplers as part of the context creation
|
||||
pimpl->samplers.resize(cparams.n_seq_max);
|
||||
pimpl->samplers_seq_config.resize(cparams.n_seq_max);
|
||||
|
||||
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
|
||||
pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
|
||||
llama_sampler * backend_chain = common_sampler_chain_backend(pimpl->samplers[i].get());
|
||||
pimpl->samplers_seq_config[i] = { i, backend_chain };
|
||||
}
|
||||
|
||||
cparams.samplers = pimpl->samplers_seq_config.data();
|
||||
cparams.n_samplers = pimpl->samplers_seq_config.size();
|
||||
|
||||
llama_context * lctx = llama_init_from_model(model, cparams);
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
|
||||
__func__, params.model.path.c_str());
|
||||
llama_model_free(model);
|
||||
return iparams;
|
||||
return;
|
||||
}
|
||||
|
||||
pimpl->context.reset(lctx);
|
||||
}
|
||||
|
||||
llama_model * common_init_result::model() {
|
||||
return pimpl->model.get();
|
||||
}
|
||||
|
||||
llama_context * common_init_result::context() {
|
||||
return pimpl->context.get();
|
||||
}
|
||||
|
||||
common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
|
||||
return pimpl->samplers[seq_id].get();
|
||||
}
|
||||
|
||||
std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
|
||||
return pimpl->lora;
|
||||
}
|
||||
|
||||
void common_init_result::free_context() {
|
||||
pimpl->context.reset();
|
||||
}
|
||||
|
||||
common_init_result_ptr common_init_from_params(common_params & params) {
|
||||
common_init_result_ptr res(new common_init_result(params));
|
||||
|
||||
llama_model * model = res->model();
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
|
||||
__func__, params.model.path.c_str());
|
||||
return res;
|
||||
}
|
||||
|
||||
llama_context * lctx = res->context();
|
||||
if (lctx == NULL) {
|
||||
LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
|
||||
__func__, params.model.path.c_str());
|
||||
return res;
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
|
||||
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
|
||||
params.ctx_shift = false;
|
||||
|
|
@ -1052,10 +1146,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||
|
||||
const auto cvec = common_control_vector_load(params.control_vectors);
|
||||
if (cvec.n_embd == -1) {
|
||||
llama_free(lctx);
|
||||
llama_model_free(model);
|
||||
|
||||
return iparams;
|
||||
return res;
|
||||
}
|
||||
|
||||
int err = llama_apply_adapter_cvec(
|
||||
|
|
@ -1066,10 +1157,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||
params.control_vector_layer_start,
|
||||
params.control_vector_layer_end);
|
||||
if (err) {
|
||||
llama_free(lctx);
|
||||
llama_model_free(model);
|
||||
|
||||
return iparams;
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1093,10 +1181,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||
}
|
||||
|
||||
if (!ok) {
|
||||
llama_free(lctx);
|
||||
llama_model_free(model);
|
||||
|
||||
return iparams;
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1106,9 +1191,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||
lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
|
||||
if (lora == nullptr) {
|
||||
LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
|
||||
llama_free(lctx);
|
||||
llama_model_free(model);
|
||||
return iparams;
|
||||
return res;
|
||||
}
|
||||
|
||||
char buf[1024];
|
||||
|
|
@ -1117,43 +1200,13 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||
la.task_name = buf;
|
||||
llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf));
|
||||
la.prompt_prefix = buf;
|
||||
iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
|
||||
res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters
|
||||
}
|
||||
|
||||
if (!params.lora_init_without_apply) {
|
||||
common_set_adapter_lora(lctx, params.lora_adapters);
|
||||
}
|
||||
|
||||
if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
||||
LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
|
||||
params.sampling.ignore_eos = false;
|
||||
}
|
||||
|
||||
// initialize once
|
||||
for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
|
||||
if (llama_vocab_is_eog(vocab, i)) {
|
||||
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
|
||||
params.sampling.logit_bias_eog.push_back({i, -INFINITY});
|
||||
}
|
||||
}
|
||||
|
||||
if (params.sampling.ignore_eos) {
|
||||
// add EOG biases to the active set of logit biases
|
||||
params.sampling.logit_bias.insert(
|
||||
params.sampling.logit_bias.end(),
|
||||
params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
|
||||
}
|
||||
|
||||
if (params.sampling.penalty_last_n == -1) {
|
||||
LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
params.sampling.penalty_last_n = llama_n_ctx(lctx);
|
||||
}
|
||||
|
||||
if (params.sampling.dry_penalty_last_n == -1) {
|
||||
LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
|
||||
params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
|
||||
}
|
||||
|
||||
if (params.warmup) {
|
||||
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
|
||||
|
||||
|
|
@ -1192,12 +1245,11 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||
llama_set_warmup(lctx, false);
|
||||
}
|
||||
|
||||
iparams.model.reset(model);
|
||||
iparams.context.reset(lctx);
|
||||
|
||||
return iparams;
|
||||
return res;
|
||||
}
|
||||
|
||||
common_init_result::~common_init_result() = default;
|
||||
|
||||
std::string get_model_endpoint() {
|
||||
const char * model_endpoint_env = getenv("MODEL_ENDPOINT");
|
||||
// We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility.
|
||||
|
|
@ -1206,7 +1258,9 @@ std::string get_model_endpoint() {
|
|||
std::string model_endpoint = "https://huggingface.co/";
|
||||
if (endpoint_env) {
|
||||
model_endpoint = endpoint_env;
|
||||
if (model_endpoint.back() != '/') model_endpoint += '/';
|
||||
if (model_endpoint.back() != '/') {
|
||||
model_endpoint += '/';
|
||||
}
|
||||
}
|
||||
return model_endpoint;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -192,7 +192,6 @@ struct common_params_sampling {
|
|||
|
||||
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
|
||||
|
||||
|
||||
std::vector<enum common_sampler_type> samplers = {
|
||||
COMMON_SAMPLER_TYPE_PENALTIES,
|
||||
COMMON_SAMPLER_TYPE_DRY,
|
||||
|
|
@ -215,6 +214,16 @@ struct common_params_sampling {
|
|||
|
||||
bool backend_sampling = false; // enable backend sampling
|
||||
|
||||
bool has_logit_bias() const {
|
||||
return !logit_bias.empty();
|
||||
}
|
||||
|
||||
bool is_disabled(enum common_sampler_type type) const;
|
||||
|
||||
// remove disabled samplers
|
||||
// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
|
||||
void filter_disabled();
|
||||
|
||||
// print the parameters into a string
|
||||
std::string print() const;
|
||||
};
|
||||
|
|
@ -650,18 +659,29 @@ std::vector<common_file_info> fs_list_files(const std::string & path);
// Model utils
//

struct common_sampler;

// note: defines object's lifetime
struct common_init_result {
llama_model_ptr model;
llama_context_ptr context;
common_init_result(common_params & params);
~common_init_result();

std::vector<llama_adapter_lora_ptr> lora;
llama_model * model();
llama_context * context();
common_sampler * sampler(llama_seq_id seq_id);

std::vector<llama_sampler_ptr> samplers;
std::vector<llama_sampler_seq_config> samplers_seq_config;
std::vector<llama_adapter_lora_ptr> & lora();

void free_context();

private:
struct impl;
std::unique_ptr<impl> pimpl;
};

struct common_init_result common_init_from_params(common_params & params);
using common_init_result_ptr = std::unique_ptr<common_init_result>;

common_init_result_ptr common_init_from_params(common_params & params);

struct llama_model_params common_model_params_to_llama ( common_params & params);
struct llama_context_params common_context_params_to_llama(const common_params & params);

@ -105,7 +105,8 @@ struct common_sampler {
|
|||
common_params_sampling params;
|
||||
|
||||
struct llama_sampler * grmr;
|
||||
struct llama_sampler * chain; // CPU sampling chain
|
||||
struct llama_sampler * chain;
|
||||
struct llama_sampler * chain_backend;
|
||||
|
||||
ring_buffer<llama_token> prev;
|
||||
|
||||
|
|
@ -118,6 +119,7 @@ struct common_sampler {
|
|||
|
||||
llama_sampler_reset(grmr);
|
||||
llama_sampler_reset(chain);
|
||||
llama_sampler_reset(chain_backend);
|
||||
}
|
||||
|
||||
void set_logits(struct llama_context * ctx, int idx) {
|
||||
|
|
@ -161,7 +163,8 @@ struct common_sampler {
|
|||
mutable int64_t t_total_us = 0;
|
||||
};
|
||||
|
||||
static bool sampler_backend_supported(enum common_sampler_type type) {
|
||||
// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
|
||||
static bool common_sampler_type_has_backend_support(enum common_sampler_type type) {
|
||||
switch (type) {
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
|
|
@ -172,98 +175,69 @@ static bool sampler_backend_supported(enum common_sampler_type type) {
|
|||
}
|
||||
}
|
||||
|
||||
static bool is_sampler_enabled(enum common_sampler_type type, const struct common_params_sampling & params) {
|
||||
bool common_params_sampling::is_disabled(enum common_sampler_type type) const {
|
||||
switch (type) {
|
||||
case COMMON_SAMPLER_TYPE_PENALTIES:
|
||||
if (params.penalty_last_n == 64 &&
|
||||
fabs(params.penalty_repeat) <= 1.0f &&
|
||||
fabs(params.penalty_freq) <= 0.0f &&
|
||||
fabs(params.penalty_present) <= 0.0f) {
|
||||
return false;
|
||||
if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_DRY:
|
||||
if (params.dry_multiplier == 0.0f && params.dry_base == 1.75f) {
|
||||
return false;
|
||||
if (dry_multiplier == 0.0f || dry_base < 1.0f || dry_penalty_last_n == 0) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TYPICAL_P:
|
||||
if (params.typ_p == 1.0) {
|
||||
return false;
|
||||
if (typ_p >= 1.0) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_N_SIGMA:
|
||||
if (params.top_n_sigma == -1.0) {
|
||||
return false;
|
||||
if (top_n_sigma <= 0.0) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
if (params.top_k <= 0) {
|
||||
return false;
|
||||
if (top_k <= 0) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
if (params.temp < 0.0f) {
|
||||
return false;
|
||||
if (dynatemp_range <= 0.0f) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
if (params.min_p <= 0.0f) {
|
||||
return false;
|
||||
if (min_p <= 0.0f) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
||||
if (params.top_p >= 1.0f) {
|
||||
return false;
|
||||
if (top_p >= 1.0f) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_XTC:
|
||||
if (params.xtc_probability == 0.0f && params.xtc_threshold == 0.10f) {
|
||||
return false;
|
||||
if (xtc_probability <= 0.0f || xtc_threshold == 0.50f) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool has_logit_bias(const struct common_params_sampling & params) {
|
||||
return !params.logit_bias.empty();
|
||||
}
|
||||
|
||||
struct active_samplers {
|
||||
std::vector<common_sampler_type> backend_samplers;
|
||||
std::vector<common_sampler_type> cpu_samplers;
|
||||
};
|
||||
|
||||
static struct active_samplers get_active_samplers(const struct common_params_sampling & params) {
|
||||
struct active_samplers result;
|
||||
|
||||
if (params.mirostat != 0) {
|
||||
// Mirostat is CPU-only and overrides other samplers
|
||||
for (const auto & sampler_type : params.samplers) {
|
||||
if (is_sampler_enabled(sampler_type, params)) {
|
||||
result.cpu_samplers.push_back(sampler_type);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool backend_supported = params.backend_sampling;
|
||||
|
||||
for (const auto & sampler_type : params.samplers) {
|
||||
if (!is_sampler_enabled(sampler_type, params)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (backend_supported && sampler_backend_supported(sampler_type)) {
|
||||
result.backend_samplers.push_back(sampler_type);
|
||||
void common_params_sampling::filter_disabled() {
|
||||
for (auto it = samplers.begin(); it != samplers.end();) {
|
||||
if (is_disabled(*it)) {
|
||||
LOG_WRN("%s: removing disabled sampler %s\n", __func__, common_sampler_type_to_str(*it).c_str());
|
||||
it = samplers.erase(it);
|
||||
} else {
|
||||
result.cpu_samplers.push_back(sampler_type);
|
||||
++it;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
std::string common_params_sampling::print() const {
|
||||
|
|
@ -282,15 +256,7 @@ std::string common_params_sampling::print() const {
|
|||
return std::string(result);
|
||||
}
|
||||
|
||||
struct backend_chain_data {
|
||||
struct llama_sampler * chain;
|
||||
size_t count;
|
||||
};
|
||||
|
||||
static struct backend_chain_data backend_samplers_init(const struct llama_model * model, const struct common_params_sampling & params,
|
||||
struct active_samplers get_active_samplers);
|
||||
|
||||
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
|
||||
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
|
||||
|
|
@ -357,29 +323,74 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
|
||||
if (params.backend_sampling) {
|
||||
params.filter_disabled();
|
||||
}
|
||||
|
||||
auto * result = new common_sampler {
|
||||
/* .params = */ params,
|
||||
/* .grmr = */ grmr,
|
||||
/* .chain = */ llama_sampler_chain_init(lparams),
|
||||
/* .chain_backend = */ llama_sampler_chain_init(lparams),
|
||||
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
||||
/* .cur = */ {},
|
||||
/* .cur_p = */ {},
|
||||
};
|
||||
|
||||
struct active_samplers active_samplers = get_active_samplers(params);
|
||||
size_t idx_smpl = 0;
|
||||
|
||||
// Build CPU chain
|
||||
if (!params.backend_sampling || !has_logit_bias(params)) {
|
||||
bool is_backend = true;
|
||||
|
||||
is_backend = is_backend && params.backend_sampling;
|
||||
is_backend = is_backend && (params.samplers.size() == 0 || common_sampler_type_has_backend_support(params.samplers[idx_smpl]));
|
||||
|
||||
if (params.has_logit_bias()) {
|
||||
if (is_backend) {
|
||||
llama_sampler_chain_add(result->chain_backend,
|
||||
llama_sampler_backend_init_logit_bias(
|
||||
llama_vocab_n_tokens(vocab),
|
||||
params.logit_bias.size(),
|
||||
params.logit_bias.data()));
|
||||
} else {
|
||||
llama_sampler_chain_add(result->chain,
|
||||
llama_sampler_init_logit_bias(
|
||||
llama_vocab_n_tokens(vocab),
|
||||
params.logit_bias.size(),
|
||||
params.logit_bias.data()));
|
||||
}
|
||||
}
|
||||
|
||||
if (params.mirostat == 0) {
|
||||
// backend samplers are added first
|
||||
while (is_backend && idx_smpl < params.samplers.size()) {
|
||||
const auto & cnstr = params.samplers[idx_smpl++];
|
||||
|
||||
if (!common_sampler_type_has_backend_support(cnstr)) {
|
||||
is_backend = false;
|
||||
--idx_smpl;
|
||||
break;
|
||||
}
|
||||
|
||||
switch (cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_top_k(params.top_k));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_temp(params.temp));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_min_p(params.min_p));
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unsupported backend sampler");
|
||||
}
|
||||
}
|
||||
|
||||
// Add remaining CPU samplers
|
||||
for (const auto & cnstr : active_samplers.cpu_samplers) {
|
||||
while (idx_smpl < params.samplers.size()) {
|
||||
const auto & cnstr = params.samplers[idx_smpl++];
|
||||
|
||||
switch (cnstr) {
|
||||
case COMMON_SAMPLER_TYPE_DRY:
|
||||
{
|
||||
|
|
@ -424,7 +435,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
}
|
||||
}
|
||||
|
||||
if (!active_samplers.cpu_samplers.empty()) {
|
||||
if (is_backend) {
|
||||
llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_dist(params.seed));
|
||||
} else {
|
||||
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
|
||||
}
|
||||
} else if (params.mirostat == 1) {
|
||||
|
|
@ -440,59 +453,11 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
return result;
|
||||
}
|
||||
|
||||
|
||||
static struct backend_chain_data backend_samplers_init(const struct llama_model * model, const struct common_params_sampling & params,
|
||||
struct active_samplers active_samplers) {
|
||||
if (active_samplers.backend_samplers.empty()) {
|
||||
return { nullptr, 0 };
|
||||
}
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
|
||||
chain_params.no_perf = params.no_perf;
|
||||
struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
|
||||
|
||||
// Add logit_bias to backend chain if present
|
||||
if (has_logit_bias(params)) {
|
||||
llama_sampler_chain_add(chain, llama_sampler_backend_init_logit_bias(
|
||||
llama_vocab_n_tokens(vocab),
|
||||
params.logit_bias.size(),
|
||||
params.logit_bias.data()));
|
||||
}
|
||||
|
||||
for (const auto & sampler_type : active_samplers.backend_samplers) {
|
||||
switch (sampler_type) {
|
||||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
llama_sampler_chain_add(chain, llama_sampler_backend_init_min_p(params.min_p));
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unsupported backend sampler");
|
||||
}
|
||||
}
|
||||
|
||||
if (active_samplers.cpu_samplers.empty()) {
|
||||
llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(params.seed));
|
||||
}
|
||||
|
||||
return { chain, active_samplers.backend_samplers.size() + has_logit_bias(params) };
|
||||
}
|
||||
|
||||
struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) {
|
||||
struct active_samplers active_samplers = get_active_samplers(params);
|
||||
return backend_samplers_init(model, params, active_samplers).chain;
|
||||
}
|
||||
|
||||
void common_sampler_free(struct common_sampler * gsmpl) {
|
||||
if (gsmpl) {
|
||||
llama_sampler_free(gsmpl->grmr);
|
||||
llama_sampler_free(gsmpl->chain);
|
||||
llama_sampler_free(gsmpl->chain_backend);
|
||||
|
||||
delete gsmpl;
|
||||
}
|
||||
|
|
@ -519,6 +484,7 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
|||
/* .params = */ gsmpl->params,
|
||||
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
||||
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
||||
/* .chain_backend = */ llama_sampler_clone(gsmpl->chain_backend),
|
||||
/* .prev = */ gsmpl->prev,
|
||||
/* .cur = */ gsmpl->cur,
|
||||
/* .cur_p = */ gsmpl->cur_p,
|
||||
|
|
@ -570,6 +536,10 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
|
|||
}
|
||||
}
|
||||
|
||||
struct llama_sampler * common_sampler_chain_backend(const struct common_sampler * gsmpl) {
|
||||
return gsmpl->chain_backend;
|
||||
}
|
||||
|
||||
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
||||
// Check if a backend sampler has already sampled a token in which case we
|
||||
// return that token id directly.
|
||||
|
|
@ -707,7 +677,12 @@ llama_token common_sampler_last(const struct common_sampler * gsmpl) {
|
|||
}
|
||||
|
||||
std::string common_sampler_print(const struct common_sampler * gsmpl) {
|
||||
std::string result = "logits ";
|
||||
std::string result = llama_sampler_chain_n(gsmpl->chain_backend) > 0 ? "*logits " : "logits ";
|
||||
|
||||
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain_backend); i++) {
|
||||
const auto * smpl = llama_sampler_chain_get(gsmpl->chain_backend, i);
|
||||
result += std::string("-> *") + llama_sampler_name(smpl) + " ";
|
||||
}
|
||||
|
||||
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
|
||||
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
|
||||
|
|
|
|||
|
|
@ -36,14 +36,8 @@ struct common_sampler;
|
|||
|
||||
// llama_sampler API overloads
|
||||
|
||||
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
|
||||
|
||||
// Create a backend sampler chain from common sampling parameters
|
||||
// Returns a llama_sampler chain configured with backend samplers based on the parameters
|
||||
// This chain can be used per-sequence for backend-based sampling
|
||||
// Note: Only samplers that have backend equivalents will be added to the chain
|
||||
// The returned sampler should be freed with llama_sampler_free()
|
||||
struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params);
|
||||
// TODO: params should become const again [LLAMA_SAMPLER_BACKEND]
|
||||
struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params);
|
||||
|
||||
void common_sampler_free(struct common_sampler * gsmpl);
|
||||
|
||||
|
|
@ -55,6 +49,8 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
|
|||
// arguments can be nullptr to skip printing
|
||||
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
|
||||
|
||||
struct llama_sampler * common_sampler_chain_backend(const struct common_sampler * gsmpl);
|
||||
|
||||
// extended sampling implementation:
|
||||
//
|
||||
// - set logits
|
||||
|
|
@ -114,3 +110,9 @@ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std:
|
|||
|
||||
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
|
||||
const char * grammar_kind, const char * grammar_data);
|
||||
|
||||
struct common_sampler_deleter {
|
||||
void operator()(common_sampler * s) { common_sampler_free(s); }
|
||||
};
|
||||
|
||||
typedef std::unique_ptr<common_sampler, common_sampler_deleter> common_sampler_ptr;
|
||||
|
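The new common_sampler_ptr makes the CPU-side sampler RAII-managed (this is how the server slots hold slot.smpl further down). A small usage sketch follows; model, ctx and params are assumed to be set up elsewhere.

common_sampler_ptr smpl(common_sampler_init(model, params.sampling));
if (smpl) {
    // idx selects which output of the last llama_decode() to sample from; -1 = last
    const llama_token id = common_sampler_sample(smpl.get(), ctx, /*idx =*/ -1);
    common_sampler_accept(smpl.get(), id, /*accept_grammar =*/ true);
}
// common_sampler_free() runs automatically via common_sampler_deleter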
|
|
|||
|
|
@ -65,29 +65,34 @@ int main(int argc, char ** argv) {
|
|||
ctx_params.n_ctx = n_kv_req;
|
||||
ctx_params.n_batch = std::max(n_predict, n_parallel);
|
||||
|
||||
std::vector<llama_sampler_seq_config> sampler_configs(n_parallel);
|
||||
if (params.sampling.backend_sampling) {
|
||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||
llama_sampler * backend_sampler = common_sampler_backend_init(model, params.sampling);
|
||||
if (backend_sampler) {
|
||||
sampler_configs[i] = { i, backend_sampler };
|
||||
}
|
||||
}
|
||||
ctx_params.samplers = sampler_configs.data();
|
||||
ctx_params.n_samplers = n_parallel;
|
||||
}
|
||||
|
||||
llama_context * ctx = llama_init_from_model(model, ctx_params);
|
||||
|
||||
auto sparams = llama_sampler_chain_default_params();
|
||||
sparams.no_perf = false;
|
||||
|
||||
std::vector<llama_sampler_seq_config> sampler_configs;
|
||||
|
||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||
|
||||
if (params.sampling.backend_sampling) {
|
||||
llama_sampler_chain_add(smpl, llama_sampler_backend_init_top_k(params.sampling.top_k));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_backend_init_temp (params.sampling.temp));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_backend_init_dist (params.sampling.seed));
|
||||
} else {
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
|
||||
}
|
||||
|
||||
sampler_configs.push_back({ i, smpl });
|
||||
}
|
||||
|
||||
if (params.sampling.backend_sampling) {
|
||||
ctx_params.samplers = sampler_configs.data();
|
||||
ctx_params.n_samplers = sampler_configs.size();
|
||||
}
|
||||
|
||||
llama_context * ctx = llama_init_from_model(model, ctx_params);
|
||||
|
||||
if (ctx == NULL) {
|
||||
LOG_ERR("%s: error: failed to create the llama_context\n" , __func__);
|
||||
|
|
@ -186,7 +191,7 @@ int main(int argc, char ** argv) {
|
|||
continue;
|
||||
}
|
||||
|
||||
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
|
||||
const llama_token new_token_id = llama_sampler_sample(sampler_configs[i].sampler, ctx, i_batch[i]);
|
||||
|
||||
// is it an end of generation? -> mark the stream as finished
|
||||
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {
|
||||
|
|
@ -242,14 +247,17 @@ int main(int argc, char ** argv) {
|
|||
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||
|
||||
LOG("\n");
|
||||
llama_perf_sampler_print(smpl);
|
||||
llama_perf_sampler_print(sampler_configs[0].sampler);
|
||||
llama_perf_context_print(ctx);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
llama_sampler_free(smpl);
|
||||
for (auto & sampler_config : sampler_configs) {
|
||||
llama_sampler_free(sampler_config.sampler);
|
||||
}
|
||||
|
||||
llama_free(ctx);
|
||||
llama_model_free(model);
|
||||
|
||||
|
|
|
|||
|
|
@ -130,10 +130,10 @@ int main(int argc, char ** argv) {
|
|||
llama_numa_init(params.numa);
|
||||
|
||||
// load the model
|
||||
common_init_result llama_init = common_init_from_params(params);
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
llama_model * model = llama_init.model.get();
|
||||
llama_context * ctx = llama_init.context.get();
|
||||
auto * model = llama_init->model();
|
||||
auto * ctx = llama_init->context();
|
||||
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: unable to load model\n", __func__);
|
||||
|
|
|
|||
|
|
@ -202,10 +202,10 @@ int main(int argc, char ** argv) {
|
|||
params.warmup = false;
|
||||
|
||||
// init
|
||||
common_init_result llama_init = common_init_from_params(params);
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
llama_model * model = llama_init.model.get();
|
||||
llama_context * ctx = llama_init.context.get();
|
||||
auto * model = llama_init->model();
|
||||
auto * ctx = llama_init->context();
|
||||
|
||||
if (model == nullptr || ctx == nullptr) {
|
||||
LOG_ERR("%s : failed to init\n", __func__);
|
||||
|
|
|
|||
|
|
@ -55,10 +55,10 @@ int main(int argc, char ** argv) {
|
|||
llama_numa_init(params.numa);
|
||||
|
||||
// load the target model
|
||||
common_init_result llama_init = common_init_from_params(params);
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
llama_model * model = llama_init.model.get();
|
||||
llama_context * ctx = llama_init.context.get();
|
||||
auto * model = llama_init->model();
|
||||
auto * ctx = llama_init->context();
|
||||
|
||||
auto * mem = llama_get_memory(ctx);
|
||||
|
||||
|
|
|
|||
|
|
@ -18,16 +18,16 @@ int main(int argc, char ** argv){
|
|||
llama_numa_init(params.numa);
|
||||
|
||||
// load the model
|
||||
common_init_result llama_init = common_init_from_params(params);
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
llama_model_ptr & model = llama_init.model;
|
||||
llama_context_ptr & ctx = llama_init.context;
|
||||
auto * model = llama_init->model();
|
||||
auto * ctx = llama_init->context();
|
||||
|
||||
GGML_ASSERT(model != nullptr);
|
||||
|
||||
// tokenize the prompt
|
||||
std::vector<llama_token> inp;
|
||||
inp = common_tokenize(ctx.get(), params.prompt, true, true);
|
||||
inp = common_tokenize(ctx, params.prompt, true, true);
|
||||
fprintf(stderr, "%s: tokenization done\n", __func__);
|
||||
|
||||
common_ngram_cache ngram_cache;
|
||||
|
|
|
|||
|
|
@ -28,13 +28,13 @@ int main(int argc, char ** argv){
|
|||
llama_numa_init(params.numa);
|
||||
|
||||
// load the model
|
||||
common_init_result llama_init = common_init_from_params(params);
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
llama_context_ptr & ctx = llama_init.context;
|
||||
llama_context * ctx = llama_init->context();
|
||||
|
||||
// tokenize the prompt
|
||||
std::vector<llama_token> inp;
|
||||
inp = common_tokenize(ctx.get(), params.prompt, true, true);
|
||||
inp = common_tokenize(ctx, params.prompt, true, true);
|
||||
|
||||
common_ngram_cache ngram_cache_context;
|
||||
common_ngram_cache ngram_cache_dynamic;
|
||||
|
|
@ -65,7 +65,7 @@ int main(int argc, char ** argv){
|
|||
}
|
||||
|
||||
const int n_input = inp.size();
|
||||
const int n_ctx = llama_n_ctx(ctx.get());
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
|
||||
int n_drafted = 0;
|
||||
int n_accept = 0;
|
||||
|
|
|
|||
|
|
@ -29,10 +29,10 @@ int main(int argc, char ** argv){
|
|||
llama_numa_init(params.numa);
|
||||
|
||||
// load the model
|
||||
common_init_result llama_init = common_init_from_params(params);
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
llama_model * model = llama_init.model.get();
|
||||
llama_context * ctx = llama_init.context.get();
|
||||
auto * model = llama_init->model();
|
||||
auto * ctx = llama_init->context();
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
|
|
|
|||
|
|
@ -192,10 +192,10 @@ int main(int argc, char ** argv) {
|
|||
llama_numa_init(params.numa);
|
||||
|
||||
// load the target model
|
||||
common_init_result llama_init = common_init_from_params(params);
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
llama_model * model = llama_init.model.get();
|
||||
llama_context * ctx = llama_init.context.get();
|
||||
auto * model = llama_init->model();
|
||||
auto * ctx = llama_init->context();
|
||||
|
||||
auto * mem = llama_get_memory(ctx);
|
||||
|
||||
|
|
|
|||
|
|
@ -149,10 +149,10 @@ int main(int argc, char ** argv) {
|
|||
llama_numa_init(params.numa);
|
||||
|
||||
// load the model
|
||||
common_init_result llama_init = common_init_from_params(params);
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
llama_model * model = llama_init.model.get();
|
||||
llama_context * ctx = llama_init.context.get();
|
||||
auto * model = llama_init->model();
|
||||
auto * ctx = llama_init->context();
|
||||
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: unable to load model\n", __func__);
|
||||
|
|
|
|||
|
|
@ -34,10 +34,10 @@ int main(int argc, char ** argv) {
|
|||
std::string result2;
|
||||
|
||||
// init
|
||||
common_init_result llama_init = common_init_from_params(params);
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
llama_model * model = llama_init.model.get();
|
||||
llama_context * ctx = llama_init.context.get();
|
||||
auto * model = llama_init->model();
|
||||
auto * ctx = llama_init->context();
|
||||
|
||||
if (model == nullptr || ctx == nullptr) {
|
||||
fprintf(stderr, "%s : failed to init\n", __func__);
|
||||
|
|
|
|||
|
|
@ -40,10 +40,10 @@ int main(int argc, char ** argv) {
|
|||
llama_context * ctx_dft = NULL;
|
||||
|
||||
// load the target model
|
||||
common_init_result llama_init_tgt = common_init_from_params(params);
|
||||
auto llama_init_tgt = common_init_from_params(params);
|
||||
|
||||
model_tgt = llama_init_tgt.model.get();
|
||||
ctx_tgt = llama_init_tgt.context.get();
|
||||
model_tgt = llama_init_tgt->model();
|
||||
ctx_tgt = llama_init_tgt->context();
|
||||
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
|
||||
|
||||
|
|
@ -61,10 +61,10 @@ int main(int argc, char ** argv) {
|
|||
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
|
||||
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
|
||||
|
||||
common_init_result llama_init_dft = common_init_from_params(params);
|
||||
auto llama_init_dft = common_init_from_params(params);
|
||||
|
||||
//model_dft = llama_init_dft.model.get();
|
||||
ctx_dft = llama_init_dft.context.get();
|
||||
//model_dft = llama_init_dft->model();
|
||||
ctx_dft = llama_init_dft->context();
|
||||
|
||||
if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
|
||||
LOG_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str());
|
||||
|
|
|
|||
|
|
@ -71,10 +71,10 @@ int main(int argc, char ** argv) {
|
|||
llama_context * ctx_dft = NULL;
|
||||
|
||||
// load the target model
|
||||
common_init_result llama_init_tgt = common_init_from_params(params);
|
||||
auto llama_init_tgt = common_init_from_params(params);
|
||||
|
||||
model_tgt = llama_init_tgt.model.get();
|
||||
ctx_tgt = llama_init_tgt.context.get();
|
||||
model_tgt = llama_init_tgt->model();
|
||||
ctx_tgt = llama_init_tgt->context();
|
||||
|
||||
// load the draft model
|
||||
params.devices = params.speculative.devices;
|
||||
|
|
@ -87,10 +87,10 @@ int main(int argc, char ** argv) {
|
|||
params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
|
||||
params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
|
||||
|
||||
common_init_result llama_init_dft = common_init_from_params(params);
|
||||
auto llama_init_dft = common_init_from_params(params);
|
||||
|
||||
model_dft = llama_init_dft.model.get();
|
||||
ctx_dft = llama_init_dft.context.get();
|
||||
model_dft = llama_init_dft->model();
|
||||
ctx_dft = llama_init_dft->context();
|
||||
|
||||
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
|
||||
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
|
||||
|
|
|
|||
|
|
@ -39,9 +39,10 @@ int main(int argc, char ** argv) {
|
|||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
// load the model and apply lora adapter, if any
|
||||
common_init_result llama_init = common_init_from_params(params);
|
||||
llama_model_ptr & model = llama_init.model;
|
||||
llama_context_ptr & ctx = llama_init.context;
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
auto * model = llama_init->model();
|
||||
auto * ctx = llama_init->context();
|
||||
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: unable to load model\n", __func__);
|
||||
|
|
@ -54,8 +55,8 @@ int main(int argc, char ** argv) {
|
|||
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
|
||||
}
|
||||
|
||||
std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
|
||||
ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get()) / 2);
|
||||
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);
|
||||
ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx, tokens, llama_n_ctx(ctx) / 2);
|
||||
|
||||
struct lr_opt & lr = params.lr;
|
||||
LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %d -period %.2g -val %.2g\n",
|
||||
|
|
@ -70,7 +71,7 @@ int main(int argc, char ** argv) {
|
|||
/*get_opt_pars_ud =*/¶ms.lr,
|
||||
/*optimizer_type =*/params.optimizer,
|
||||
};
|
||||
llama_opt_init(ctx.get(), model.get(), lopt_params);
|
||||
llama_opt_init(ctx, model, lopt_params);
|
||||
|
||||
const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);
|
||||
|
||||
|
|
@ -78,7 +79,7 @@ int main(int argc, char ** argv) {
|
|||
ggml_opt_result_t result_eval = ggml_opt_result_init();
|
||||
|
||||
for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) {
|
||||
llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
|
||||
llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
|
||||
ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
|
|
@ -88,7 +89,7 @@ int main(int argc, char ** argv) {
|
|||
ggml_opt_result_free(result_train);
|
||||
ggml_opt_result_free(result_eval);
|
||||
|
||||
llama_model_save_to_file(model.get(), params.out_file.c_str());
|
||||
llama_model_save_to_file(model, params.out_file.c_str());
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
|
|
|
|||
|
|
@ -376,7 +376,7 @@ extern "C" {
|
|||
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
|
||||
|
||||
// backend sampler chain configuration
|
||||
// backend sampler chain configuration (does not keep a reference, so make sure the caller keeps the samplers alive)
|
||||
struct llama_sampler_seq_config * samplers;
|
||||
size_t n_samplers;
|
||||
};
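A hedged sketch of wiring per-sequence backend sampler chains through these fields, mirroring the batched example further down; n_seq, ctx_params and the top-k/seed values are placeholders.

std::vector<llama_sampler_seq_config> configs(n_seq);
for (int32_t i = 0; i < n_seq; ++i) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(1234));
    configs[i] = { i, chain };
}
ctx_params.samplers   = configs.data();
ctx_params.n_samplers = configs.size();
// the context does not take ownership: keep `configs` alive and free each chain
// with llama_sampler_free() after llama_free(ctx)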
|
||||
|
|
@ -1209,10 +1209,6 @@ extern "C" {
|
|||
|
||||
void (*init_ggml)(struct llama_sampler * smpl,
|
||||
ggml_backend_buffer_type_t buft);
|
||||
|
||||
|
||||
// TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
|
||||
//void (*apply_ggml) (struct llama_sampler * smpl, ...);
|
||||
};
|
||||
|
||||
struct llama_sampler {
|
||||
|
|
@ -1231,15 +1227,14 @@ extern "C" {
|
|||
LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl);
|
||||
// important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
|
||||
LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl);
|
||||
LLAMA_API void llama_sampler_init_ggml(struct llama_sampler * smpl,
|
||||
ggml_backend_buffer_type_t buft);
|
||||
|
||||
LLAMA_API void llama_sampler_init_ggml (struct llama_sampler * smpl, ggml_backend_buffer_type_t buft);
|
||||
LLAMA_API void llama_sampler_set_input_ggml(struct llama_sampler * smpl);
|
||||
LLAMA_API void llama_sampler_apply_ggml( struct llama_sampler * smpl,
|
||||
LLAMA_API void llama_sampler_apply_ggml (struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct llama_sampler_ggml_data * ggml_data);
|
||||
|
||||
LLAMA_API void llama_sampler_accept_ggml( struct llama_sampler * smpl,
|
||||
LLAMA_API void llama_sampler_accept_ggml (struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_tensor * selected_token);
|
||||
|
|
|
|||
|
|
@ -66,7 +66,15 @@ llama_context::llama_context(
|
|||
|
||||
for (size_t i = 0; i < params.n_samplers; ++i) {
|
||||
const auto & config = params.samplers[i];
|
||||
|
||||
const int n_samplers = llama_sampler_chain_n(config.sampler);
|
||||
if (n_samplers <= 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
sampling.samplers[config.seq_id] = config.sampler;
|
||||
|
||||
LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -438,8 +446,8 @@ llama_context::llama_context(
|
|||
|
||||
// Initialize the full vocabulary token ids for backend samplers.
|
||||
{
|
||||
const llama_vocab * vocab = llama_model_get_vocab(&model);
|
||||
const int n_vocab = llama_vocab_n_tokens(vocab);
|
||||
const int n_vocab = model.vocab.n_tokens();
|
||||
|
||||
sampling.token_ids_full_vocab.resize(n_vocab);
|
||||
for (int i = 0; i < n_vocab; ++i) {
|
||||
sampling.token_ids_full_vocab[i] = i;
|
||||
|
|
@ -449,10 +457,6 @@ llama_context::llama_context(
|
|||
|
||||
llama_context::~llama_context() {
|
||||
ggml_opt_free(opt_ctx);
|
||||
// TODO: perhaps use a smart pointer for samplers
|
||||
for (auto const& [seq_id, sampler] : sampling.samplers) {
|
||||
llama_sampler_free(sampler);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_context::synchronize() {
|
||||
|
|
@ -910,31 +914,10 @@ void llama_context::set_warmup(bool value) {
|
|||
void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
|
||||
LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
|
||||
|
||||
auto it = sampling.samplers.find(seq_id);
|
||||
if (it != sampling.samplers.end()) {
|
||||
// If the sampler to be set is the same that is already set, do nothing.
|
||||
if (it->second == sampler) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_sampler_free(it->second);
|
||||
|
||||
// If sampler is nullptr, we remove the sampler chain for this seq_id.
|
||||
// chain for this seq_id.
|
||||
if (sampler == nullptr) {
|
||||
sampling.samplers.erase(it);
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, we replace the existing sampler with the new one.
|
||||
it->second = sampler;
|
||||
return;
|
||||
}
|
||||
|
||||
// If there is no sampler for this seq_id and the caller provides a non-null
|
||||
// sampler, we set it.
|
||||
if (sampler != nullptr) {
|
||||
if (sampler != nullptr && llama_sampler_chain_n(sampler) > 0) {
|
||||
sampling.samplers[seq_id] = sampler;
|
||||
} else {
|
||||
sampling.samplers.erase(seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -258,7 +258,7 @@ private:
|
|||
float * logits = nullptr;
|
||||
|
||||
struct sampling_info {
|
||||
std::unordered_map<llama_seq_id, llama_sampler*> samplers;
|
||||
std::unordered_map<llama_seq_id, llama_sampler *> samplers;
|
||||
|
||||
float * logits = nullptr;
|
||||
size_t logits_size = 0;
|
||||
|
|
|
|||
|
|
@ -439,9 +439,9 @@ void llama_sampler_free(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
|
||||
const llama_token sampled_token = llama_get_backend_sampled_token_ith(ctx, idx);
|
||||
const float * sampled_probs = llama_get_backend_sampled_probs_ith(ctx, idx);
|
||||
const float * sampled_logits = llama_get_backend_sampled_logits_ith(ctx, idx);
|
||||
const llama_token sampled_token = llama_get_backend_sampled_token_ith (ctx, idx);
|
||||
const float * sampled_probs = llama_get_backend_sampled_probs_ith (ctx, idx);
|
||||
const float * sampled_logits = llama_get_backend_sampled_logits_ith (ctx, idx);
|
||||
const llama_token * sampled_ids = llama_get_backend_sampled_candidates_ith(ctx, idx);
|
||||
|
||||
// If a backend sampler has already sampled a token, return it.
|
||||
|
|
|
|||
|
|
@ -419,10 +419,10 @@ int main(int argc, char ** argv) {
|
|||
llama_numa_init(params.numa);
|
||||
|
||||
// load the model to get hparams
|
||||
common_init_result llama_init = common_init_from_params(params);
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
llama_model * model = llama_init.model.get();
|
||||
llama_context * ctx = llama_init.context.get();
|
||||
auto * model = llama_init->model();
|
||||
auto * ctx = llama_init->context();
|
||||
|
||||
// int n_ctx = llama_n_ctx(ctx);
|
||||
int n_layers = llama_model_n_layer(model);
|
||||
|
|
|
|||
|
|
@ -1265,10 +1265,10 @@ int main(int argc, char ** argv) {
|
|||
params.warmup = false;
|
||||
|
||||
// init
|
||||
common_init_result llama_init = common_init_from_params(params);
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
llama_model * model = llama_init.model.get();
|
||||
llama_context * ctx = llama_init.context.get();
|
||||
auto * model = llama_init->model();
|
||||
auto * ctx = llama_init->context();
|
||||
|
||||
if (model == nullptr || ctx == nullptr) {
|
||||
LOG_ERR("%s : failed to init\n", __func__);
|
||||
|
|
|
|||
|
|
@ -138,9 +138,11 @@ int main(int argc, char ** argv) {
|
|||
// load the model and apply lora adapter, if any
|
||||
LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__);
|
||||
|
||||
common_init_result llama_init = common_init_from_params(params);
|
||||
ctx = llama_init.context.get();
|
||||
model = llama_init.model.get(); // Update pointer (now managed by llama_init)
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
ctx = llama_init->context();
|
||||
model = llama_init->model();
|
||||
smpl = llama_init->sampler(0);
|
||||
|
||||
if (ctx == NULL) {
|
||||
LOG_ERR("%s: error: unable to create context\n", __func__);
|
||||
|
|
@ -470,12 +472,6 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
smpl = common_sampler_init(model, sparams);
|
||||
if (!smpl) {
|
||||
LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
LOG_INF("sampler seed: %u\n", common_sampler_get_seed(smpl));
|
||||
LOG_INF("sampler params: \n%s\n", sparams.print().c_str());
|
||||
LOG_INF("sampler chain: %s\n", common_sampler_print(smpl).c_str());
|
||||
|
|
@ -989,8 +985,6 @@ int main(int argc, char ** argv) {
|
|||
LOG("\n\n");
|
||||
common_perf_print(ctx, smpl);
|
||||
|
||||
common_sampler_free(smpl);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
ggml_threadpool_free_fn(threadpool);
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ static void sigint_handler(int signo) {
|
|||
|
||||
struct mtmd_cli_context {
|
||||
mtmd::context_ptr ctx_vision;
|
||||
common_init_result llama_init;
|
||||
common_init_result_ptr llama_init;
|
||||
|
||||
llama_model * model;
|
||||
llama_context * lctx;
|
||||
|
|
@ -89,8 +89,8 @@ struct mtmd_cli_context {
|
|||
llama_pos n_past = 0;
|
||||
|
||||
mtmd_cli_context(common_params & params) : llama_init(common_init_from_params(params)) {
|
||||
model = llama_init.model.get();
|
||||
lctx = llama_init.context.get();
|
||||
model = llama_init->model();
|
||||
lctx = llama_init->context();
|
||||
vocab = llama_model_get_vocab(model);
|
||||
smpl = common_sampler_init(model, params.sampling);
|
||||
n_threads = params.cpuparams.n_threads;
|
||||
|
|
|
|||
|
|
@ -2024,10 +2024,10 @@ int main(int argc, char ** argv) {
|
|||
llama_numa_init(params.numa);
|
||||
|
||||
// load the model and apply lora adapter, if any
|
||||
common_init_result llama_init = common_init_from_params(params);
|
||||
auto llama_init = common_init_from_params(params);
|
||||
|
||||
llama_model * model = llama_init.model.get();
|
||||
llama_context * ctx = llama_init.context.get();
|
||||
auto * model = llama_init->model();
|
||||
auto * ctx = llama_init->context();
|
||||
|
||||
if (model == NULL) {
|
||||
LOG_ERR("%s: unable to load model\n", __func__);
|
||||
|
|
|
|||
|
|
@ -151,8 +151,7 @@ struct server_slot {
|
|||
// sampling
|
||||
json json_schema;
|
||||
|
||||
struct common_sampler * smpl = nullptr;
|
||||
llama_sampler * backend_sampler = nullptr;
|
||||
common_sampler_ptr smpl;
|
||||
|
||||
llama_token sampled;
|
||||
|
||||
|
|
@ -198,13 +197,6 @@ struct server_slot {
|
|||
n_draft_total = 0;
|
||||
n_draft_accepted = 0;
|
||||
|
||||
if (backend_sampler != nullptr) {
|
||||
if (ctx != nullptr) {
|
||||
llama_set_backend_sampler(ctx, id, nullptr);
|
||||
}
|
||||
backend_sampler = nullptr;
|
||||
}
|
||||
|
||||
task.reset();
|
||||
task_prev.reset();
|
||||
|
||||
|
|
@ -481,8 +473,8 @@ struct server_context {
|
|||
common_params params_base;
|
||||
|
||||
// note: keep these alive - they determine the lifetime of the model, context, etc.
|
||||
common_init_result llama_init;
|
||||
common_init_result llama_init_dft;
|
||||
common_init_result_ptr llama_init;
|
||||
common_init_result_ptr llama_init_dft;
|
||||
|
||||
llama_model * model = nullptr;
|
||||
llama_context * ctx = nullptr;
|
||||
|
|
@ -526,16 +518,6 @@ struct server_context {
|
|||
|
||||
// Clear any sampling context
|
||||
for (server_slot & slot : slots) {
|
||||
common_sampler_free(slot.smpl);
|
||||
slot.smpl = nullptr;
|
||||
|
||||
if (slot.backend_sampler != nullptr) {
|
||||
if (ctx != nullptr) {
|
||||
llama_set_backend_sampler(ctx, slot.id, nullptr);
|
||||
}
|
||||
slot.backend_sampler = nullptr;
|
||||
}
|
||||
|
||||
llama_free(slot.ctx_dft);
|
||||
slot.ctx_dft = nullptr;
|
||||
|
||||
|
|
@ -556,8 +538,8 @@ struct server_context {
|
|||
|
||||
llama_init = common_init_from_params(params_base);
|
||||
|
||||
model = llama_init.model.get();
|
||||
ctx = llama_init.context.get();
|
||||
model = llama_init->model();
|
||||
ctx = llama_init->context();
|
||||
|
||||
if (model == nullptr) {
|
||||
SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
|
||||
|
|
@ -589,25 +571,25 @@ struct server_context {
|
|||
|
||||
llama_init_dft = common_init_from_params(params_dft);
|
||||
|
||||
model_dft = llama_init_dft.model.get();
|
||||
model_dft = llama_init_dft->model();
|
||||
|
||||
if (model_dft == nullptr) {
|
||||
SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get());
|
||||
vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft->context());
|
||||
if (!vocab_dft_compatible) {
|
||||
SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str());
|
||||
}
|
||||
|
||||
const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get());
|
||||
const int n_ctx_dft = llama_n_ctx(llama_init_dft->context());
|
||||
|
||||
cparams_dft = common_context_params_to_llama(params_dft);
|
||||
cparams_dft.n_batch = n_ctx_dft;
|
||||
|
||||
// the context is not needed - we will create one for each slot
|
||||
llama_init_dft.context.reset();
|
||||
llama_init_dft->free_context();
|
||||
}
|
||||
|
||||
chat_templates = common_chat_templates_init(model, params_base.chat_template);
|
||||
|
|
@ -1001,23 +983,17 @@ struct server_context {
|
|||
|
||||
// initialize samplers
|
||||
{
|
||||
if (slot.smpl != nullptr) {
|
||||
common_sampler_free(slot.smpl);
|
||||
}
|
||||
|
||||
slot.smpl = common_sampler_init(model, task.params.sampling);
|
||||
slot.smpl.reset(common_sampler_init(model, task.params.sampling));
|
||||
if (slot.smpl == nullptr) {
|
||||
// for now, the only error that may happen here is invalid grammar
|
||||
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
}
|
||||
|
||||
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str());
|
||||
}
|
||||
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
|
||||
|
||||
if (!configure_slot_backend_sampler(slot, task.params.sampling)) {
|
||||
send_error(task, "Failed to configure backend samplers", ERROR_TYPE_SERVER);
|
||||
return false;
|
||||
llama_sampler * backend_chain = common_sampler_chain_backend(slot.smpl.get());
|
||||
llama_set_backend_sampler(ctx, slot.id, backend_chain);
|
||||
}
|
||||
|
||||
// initialize draft batch
|
||||
|
|
@ -1037,39 +1013,6 @@ struct server_context {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool configure_slot_backend_sampler(server_slot & slot, const common_params_sampling & sampling) {
|
||||
if (!sampling.backend_sampling) {
|
||||
if (slot.backend_sampler != nullptr) {
|
||||
llama_set_backend_sampler(ctx, slot.id, nullptr);
|
||||
slot.backend_sampler = nullptr;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
llama_sampler * backend_chain = common_sampler_backend_init(model, sampling);
|
||||
// The sampler types configured with --samplers might not be supported
|
||||
// by backend samplers in which case we disable backend sampling and
|
||||
// fallback to CPU only sampling.
|
||||
if (backend_chain == nullptr) {
|
||||
if (slot.backend_sampler != nullptr) {
|
||||
llama_set_backend_sampler(ctx, slot.id, nullptr);
|
||||
slot.backend_sampler = nullptr;
|
||||
}
|
||||
SLT_INF(slot, "%s", "no backend samplers configured (sampler chain doesn't start with backend-supported samplers)\n");
|
||||
return true;
|
||||
}
|
||||
|
||||
if (slot.backend_sampler != nullptr) {
|
||||
llama_set_backend_sampler(ctx, slot.id, nullptr);
|
||||
slot.backend_sampler = nullptr;
|
||||
}
|
||||
|
||||
slot.backend_sampler = backend_chain;
|
||||
llama_set_backend_sampler(ctx, slot.id, backend_chain);
|
||||
SLT_INF(slot, "%s", "configured backend samplers\n");
|
||||
return true;
|
||||
}
|
||||
|
||||
bool process_token(completion_token_output & result, server_slot & slot) {
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
const std::string token_str = result.text_to_send;
|
||||
|
|
@ -1206,7 +1149,7 @@ struct server_context {
|
|||
size_t n_vocab = llama_vocab_n_tokens(vocab);
|
||||
|
||||
if (post_sampling) {
|
||||
const auto * cur_p = common_sampler_get_candidates(slot.smpl, true);
|
||||
const auto * cur_p = common_sampler_get_candidates(slot.smpl.get(), true);
|
||||
const size_t max_probs = cur_p->size;
|
||||
|
||||
// set probability for sampled token
|
||||
|
|
@ -2185,13 +2128,13 @@ struct server_context {
|
|||
|
||||
GGML_ASSERT(batch.n_tokens > 0);
|
||||
|
||||
common_sampler_reset(slot.smpl);
|
||||
common_sampler_reset(slot.smpl.get());
|
||||
|
||||
// Process all prompt tokens through sampler system
|
||||
for (int i = 0; i < slot.task->n_tokens(); ++i) {
|
||||
llama_token id = input_tokens[i];
|
||||
if (id != LLAMA_TOKEN_NULL) {
|
||||
common_sampler_accept(slot.smpl, id, false);
|
||||
common_sampler_accept(slot.smpl.get(), id, false);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -2381,11 +2324,11 @@ struct server_context {
|
|||
|
||||
const int tok_idx = slot.i_batch - i;
|
||||
|
||||
llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
|
||||
llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx);
|
||||
|
||||
slot.i_batch = -1;
|
||||
|
||||
common_sampler_accept(slot.smpl, id, true);
|
||||
common_sampler_accept(slot.smpl.get(), id, true);
|
||||
|
||||
slot.n_decoded += 1;
|
||||
|
||||
|
|
@ -2488,7 +2431,7 @@ struct server_context {
|
|||
llama_decode(ctx, slot.batch_spec);
|
||||
|
||||
// the accepted tokens from the speculation
|
||||
const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
|
||||
const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, draft);
|
||||
|
||||
slot.n_decoded += ids.size();
|
||||
|
||||
|
|
|
|||
|
|
@ -568,10 +568,10 @@ int main(int argc, char ** argv) {
|
|||
llama_context * ctx_ttc = NULL;
|
||||
llama_context * ctx_cts = NULL;
|
||||
|
||||
common_init_result llama_init_ttc = common_init_from_params(params);
|
||||
auto llama_init_ttc = common_init_from_params(params);
|
||||
|
||||
model_ttc = llama_init_ttc.model.get();
|
||||
ctx_ttc = llama_init_ttc.context.get();
|
||||
model_ttc = llama_init_ttc->model();
|
||||
ctx_ttc = llama_init_ttc->context();
|
||||
|
||||
if (model_ttc == nullptr || ctx_ttc == nullptr) {
|
||||
return ENOENT;
|
||||
|
|
@ -583,10 +583,10 @@ int main(int argc, char ** argv) {
|
|||
params.embedding = true;
|
||||
params.n_ubatch = params.n_batch;
|
||||
|
||||
common_init_result llama_init_cts = common_init_from_params(params);
|
||||
auto llama_init_cts = common_init_from_params(params);
|
||||
|
||||
model_cts = llama_init_cts.model.get();
|
||||
ctx_cts = llama_init_cts.context.get();
|
||||
model_cts = llama_init_cts->model();
|
||||
ctx_cts = llama_init_cts->context();
|
||||
|
||||
if (model_cts == nullptr || ctx_cts == nullptr) {
|
||||
return ENOENT;
|
||||
|
|
|
|||