diff --git a/common/common.cpp b/common/common.cpp index 6606ab5c59..321ea90e7e 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -950,31 +950,40 @@ std::vector fs_list_files(const std::string & path) { // Model utils // -static inline void common_init_sampler_from_model( +// TODO: move to common/sampling +static void common_init_sampler_from_model( const llama_model * model, common_params_sampling & sparams) { const uint64_t config = sparams.user_sampling_config; auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) { - if (config & user_config) return; + if (config & user_config) { + return; + } char buf[64] = {0}; if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { char * end = nullptr; int32_t v = strtol(buf, &end, 10); - if (end && end != buf) dst = v; + if (end && end != buf) { + dst = v; + } } }; auto get_float = [&](const char * key, float & dst, uint64_t user_config) { - if (config & user_config) return; + if (config & user_config) { + return; + } char buf[128] = {0}; if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) { char * end = nullptr; float v = strtof(buf, &end); - if (end && end != buf) dst = v; + if (end && end != buf) { + dst = v; + } } }; @@ -1002,45 +1011,130 @@ static inline void common_init_sampler_from_model( get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA); } -struct common_init_result common_init_from_params(common_params & params) { - common_init_result iparams; - auto mparams = common_model_params_to_llama(params); +struct common_init_result::impl { + impl() = default; + ~impl() = default; + + llama_model_ptr model; + llama_context_ptr context; + + std::vector lora; + + std::vector samplers; + std::vector samplers_seq_config; +}; + +common_init_result::common_init_result(common_params & params) : + pimpl(new impl{}) { + const auto mparams = common_model_params_to_llama(params); llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams); if (model == NULL) { - LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n", - __func__, params.model.path.c_str()); - return iparams; + return; } - common_init_sampler_from_model(model, params.sampling); + pimpl->model.reset(model); const llama_vocab * vocab = llama_model_get_vocab(model); + // updates params.sampling + // TODO: fix naming + common_init_sampler_from_model(model, params.sampling); + auto cparams = common_context_params_to_llama(params); - if (params.sampling.backend_sampling) { - llama_sampler * backend_chain = common_sampler_backend_init(model, params.sampling); - if (backend_chain != nullptr) { - iparams.samplers_seq_config.resize(cparams.n_seq_max); - for (int i = 0; i < (int) cparams.n_seq_max; ++i) { - iparams.samplers_seq_config[i] = { i, llama_sampler_clone(backend_chain) }; - } - cparams.samplers = iparams.samplers_seq_config.data(); - cparams.n_samplers = cparams.n_seq_max; + if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { + LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__); + params.sampling.ignore_eos = false; + } - llama_sampler_free(backend_chain); + // initialize once + for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { + if (llama_vocab_is_eog(vocab, i)) { + LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(vocab, i).c_str(), 
-INFINITY); + params.sampling.logit_bias_eog.push_back({i, -INFINITY}); } } + if (params.sampling.ignore_eos) { + // add EOG biases to the active set of logit biases + params.sampling.logit_bias.insert( + params.sampling.logit_bias.end(), + params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end()); + } + + //if (params.sampling.penalty_last_n == -1) { + // LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); + // params.sampling.penalty_last_n = llama_n_ctx(lctx); + //} + + //if (params.sampling.dry_penalty_last_n == -1) { + // LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); + // params.sampling.dry_penalty_last_n = llama_n_ctx(lctx); + //} + + // init the backend samplers as part of the context creation + pimpl->samplers.resize(cparams.n_seq_max); + pimpl->samplers_seq_config.resize(cparams.n_seq_max); + + for (int i = 0; i < (int) cparams.n_seq_max; ++i) { + pimpl->samplers[i].reset(common_sampler_init(model, params.sampling)); + llama_sampler * backend_chain = common_sampler_chain_backend(pimpl->samplers[i].get()); + pimpl->samplers_seq_config[i] = { i, backend_chain }; + } + + cparams.samplers = pimpl->samplers_seq_config.data(); + cparams.n_samplers = pimpl->samplers_seq_config.size(); + llama_context * lctx = llama_init_from_model(model, cparams); if (lctx == NULL) { LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n", - __func__, params.model.path.c_str()); - llama_model_free(model); - return iparams; + __func__, params.model.path.c_str()); + return; } + pimpl->context.reset(lctx); +} + +llama_model * common_init_result::model() { + return pimpl->model.get(); +} + +llama_context * common_init_result::context() { + return pimpl->context.get(); +} + +common_sampler * common_init_result::sampler(llama_seq_id seq_id) { + return pimpl->samplers[seq_id].get(); +} + +std::vector & common_init_result::lora() { + return pimpl->lora; +} + +void common_init_result::free_context() { + pimpl->context.reset(); +} + +common_init_result_ptr common_init_from_params(common_params & params) { + common_init_result_ptr res(new common_init_result(params)); + + llama_model * model = res->model(); + if (model == NULL) { + LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n", + __func__, params.model.path.c_str()); + return res; + } + + llama_context * lctx = res->context(); + if (lctx == NULL) { + LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n", + __func__, params.model.path.c_str()); + return res; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) { LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__); params.ctx_shift = false; @@ -1052,10 +1146,7 @@ struct common_init_result common_init_from_params(common_params & params) { const auto cvec = common_control_vector_load(params.control_vectors); if (cvec.n_embd == -1) { - llama_free(lctx); - llama_model_free(model); - - return iparams; + return res; } int err = llama_apply_adapter_cvec( @@ -1066,10 +1157,7 @@ struct common_init_result common_init_from_params(common_params & params) { params.control_vector_layer_start, params.control_vector_layer_end); if (err) { - llama_free(lctx); - llama_model_free(model); - - return iparams; + return res; } 
} @@ -1093,10 +1181,7 @@ struct common_init_result common_init_from_params(common_params & params) { } if (!ok) { - llama_free(lctx); - llama_model_free(model); - - return iparams; + return res; } } @@ -1106,9 +1191,7 @@ struct common_init_result common_init_from_params(common_params & params) { lora.reset(llama_adapter_lora_init(model, la.path.c_str())); if (lora == nullptr) { LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); - llama_free(lctx); - llama_model_free(model); - return iparams; + return res; } char buf[1024]; @@ -1117,43 +1200,13 @@ struct common_init_result common_init_from_params(common_params & params) { la.task_name = buf; llama_adapter_meta_val_str(la.ptr, "adapter.lora.prompt_prefix", buf, sizeof(buf)); la.prompt_prefix = buf; - iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters + res->lora().emplace_back(std::move(lora)); // copy to list of loaded adapters } if (!params.lora_init_without_apply) { common_set_adapter_lora(lctx, params.lora_adapters); } - if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { - LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__); - params.sampling.ignore_eos = false; - } - - // initialize once - for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) { - if (llama_vocab_is_eog(vocab, i)) { - LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY); - params.sampling.logit_bias_eog.push_back({i, -INFINITY}); - } - } - - if (params.sampling.ignore_eos) { - // add EOG biases to the active set of logit biases - params.sampling.logit_bias.insert( - params.sampling.logit_bias.end(), - params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end()); - } - - if (params.sampling.penalty_last_n == -1) { - LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); - params.sampling.penalty_last_n = llama_n_ctx(lctx); - } - - if (params.sampling.dry_penalty_last_n == -1) { - LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); - params.sampling.dry_penalty_last_n = llama_n_ctx(lctx); - } - if (params.warmup) { LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); @@ -1192,12 +1245,11 @@ struct common_init_result common_init_from_params(common_params & params) { llama_set_warmup(lctx, false); } - iparams.model.reset(model); - iparams.context.reset(lctx); - - return iparams; + return res; } +common_init_result::~common_init_result() = default; + std::string get_model_endpoint() { const char * model_endpoint_env = getenv("MODEL_ENDPOINT"); // We still respect the use of environment-variable "HF_ENDPOINT" for backward-compatibility. 
@@ -1206,7 +1258,9 @@ std::string get_model_endpoint() {
     std::string model_endpoint = "https://huggingface.co/";
     if (endpoint_env) {
         model_endpoint = endpoint_env;
-        if (model_endpoint.back() != '/') model_endpoint += '/';
+        if (model_endpoint.back() != '/') {
+            model_endpoint += '/';
+        }
     }
     return model_endpoint;
 }
diff --git a/common/common.h b/common/common.h
index 8e78ab32ea..9b53d2b56f 100644
--- a/common/common.h
+++ b/common/common.h
@@ -192,7 +192,6 @@ struct common_params_sampling {
     std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
-
     std::vector<enum common_sampler_type> samplers = {
         COMMON_SAMPLER_TYPE_PENALTIES,
         COMMON_SAMPLER_TYPE_DRY,
@@ -215,6 +214,16 @@ struct common_params_sampling {
     bool backend_sampling = false; // enable backend sampling
+    bool has_logit_bias() const {
+        return !logit_bias.empty();
+    }
+
+    bool is_disabled(enum common_sampler_type type) const;
+
+    // remove disabled samplers
+    // TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
+    void filter_disabled();
+
     // print the parameters into a string
     std::string print() const;
 };
@@ -650,18 +659,29 @@ std::vector<std::string> fs_list_files(const std::string & path);
 // Model utils
 //
+struct common_sampler;
+
 // note: defines object's lifetime
 struct common_init_result {
-    llama_model_ptr model;
-    llama_context_ptr context;
+    common_init_result(common_params & params);
+    ~common_init_result();
-    std::vector<llama_adapter_lora_ptr> lora;
+    llama_model * model();
+    llama_context * context();
+    common_sampler * sampler(llama_seq_id seq_id);
-    std::vector samplers;
-    std::vector<llama_sampler_seq_config> samplers_seq_config;
+    std::vector<llama_adapter_lora_ptr> & lora();
+
+    void free_context();
+
+private:
+    struct impl;
+    std::unique_ptr<impl> pimpl;
 };
-struct common_init_result common_init_from_params(common_params & params);
+using common_init_result_ptr = std::unique_ptr<common_init_result>;
+
+common_init_result_ptr common_init_from_params(common_params & params);
 struct llama_model_params   common_model_params_to_llama  (      common_params & params);
 struct llama_context_params common_context_params_to_llama(const common_params & params);
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 9954e2519d..a831eac18b 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -105,7 +105,8 @@ struct common_sampler {
     common_params_sampling params;
     struct llama_sampler * grmr;
-    struct llama_sampler * chain; // CPU sampling chain
+    struct llama_sampler * chain;
+    struct llama_sampler * chain_backend;
     ring_buffer<llama_token> prev;
@@ -118,6 +119,7 @@ struct common_sampler {
         llama_sampler_reset(grmr);
         llama_sampler_reset(chain);
+        llama_sampler_reset(chain_backend);
     }
     void set_logits(struct llama_context * ctx, int idx) {
@@ -161,7 +163,8 @@ struct common_sampler {
     mutable int64_t t_total_us = 0;
 };
-static bool sampler_backend_supported(enum common_sampler_type type) {
+// TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND]
+static bool common_sampler_type_has_backend_support(enum common_sampler_type type) {
     switch (type) {
         case COMMON_SAMPLER_TYPE_TOP_K:
         case COMMON_SAMPLER_TYPE_TEMPERATURE:
@@ -172,98 +175,69 @@ static bool sampler_backend_supported(enum common_sampler_type type) {
     }
 }
-static bool is_sampler_enabled(enum common_sampler_type type, const struct common_params_sampling & params) {
+bool common_params_sampling::is_disabled(enum common_sampler_type type) const {
    switch (type) {
        case COMMON_SAMPLER_TYPE_PENALTIES:
-            if (params.penalty_last_n == 64 &&
-                fabs(params.penalty_repeat) <= 1.0f &&
-                fabs(params.penalty_freq) <= 
0.0f && - fabs(params.penalty_present) <= 0.0f) { - return false; + if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { + return true; } break; case COMMON_SAMPLER_TYPE_DRY: - if (params.dry_multiplier == 0.0f && params.dry_base == 1.75f) { - return false; + if (dry_multiplier == 0.0f || dry_base < 1.0f || dry_penalty_last_n == 0) { + return true; } break; case COMMON_SAMPLER_TYPE_TYPICAL_P: - if (params.typ_p == 1.0) { - return false; + if (typ_p >= 1.0) { + return true; } break; case COMMON_SAMPLER_TYPE_TOP_N_SIGMA: - if (params.top_n_sigma == -1.0) { - return false; + if (top_n_sigma <= 0.0) { + return true; } break; case COMMON_SAMPLER_TYPE_TOP_K: - if (params.top_k <= 0) { - return false; + if (top_k <= 0) { + return true; } break; case COMMON_SAMPLER_TYPE_TEMPERATURE: - if (params.temp < 0.0f) { - return false; + if (dynatemp_range <= 0.0f) { + return true; } break; case COMMON_SAMPLER_TYPE_MIN_P: - if (params.min_p <= 0.0f) { - return false; + if (min_p <= 0.0f) { + return true; } break; case COMMON_SAMPLER_TYPE_TOP_P: - if (params.top_p >= 1.0f) { - return false; + if (top_p >= 1.0f) { + return true; } break; case COMMON_SAMPLER_TYPE_XTC: - if (params.xtc_probability == 0.0f && params.xtc_threshold == 0.10f) { - return false; + if (xtc_probability <= 0.0f || xtc_threshold == 0.50f) { + return true; } break; default: break; } - return true; + + return false; } -static bool has_logit_bias(const struct common_params_sampling & params) { - return !params.logit_bias.empty(); -} - -struct active_samplers { - std::vector backend_samplers; - std::vector cpu_samplers; -}; - -static struct active_samplers get_active_samplers(const struct common_params_sampling & params) { - struct active_samplers result; - - if (params.mirostat != 0) { - // Mirostat is CPU-only and overrides other samplers - for (const auto & sampler_type : params.samplers) { - if (is_sampler_enabled(sampler_type, params)) { - result.cpu_samplers.push_back(sampler_type); - } - } - return result; - } - - bool backend_supported = params.backend_sampling; - - for (const auto & sampler_type : params.samplers) { - if (!is_sampler_enabled(sampler_type, params)) { - continue; - } - - if (backend_supported && sampler_backend_supported(sampler_type)) { - result.backend_samplers.push_back(sampler_type); +void common_params_sampling::filter_disabled() { + for (auto it = samplers.begin(); it != samplers.end();) { + if (is_disabled(*it)) { + LOG_WRN("%s: removing disabled sampler %s\n", __func__, common_sampler_type_to_str(*it).c_str()); + it = samplers.erase(it); } else { - result.cpu_samplers.push_back(sampler_type); + ++it; } } - return result; } std::string common_params_sampling::print() const { @@ -282,15 +256,7 @@ std::string common_params_sampling::print() const { return std::string(result); } -struct backend_chain_data { - struct llama_sampler * chain; - size_t count; -}; - -static struct backend_chain_data backend_samplers_init(const struct llama_model * model, const struct common_params_sampling & params, - struct active_samplers get_active_samplers); - -struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { +struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) { const llama_vocab * vocab = llama_model_get_vocab(model); llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); @@ -357,29 +323,74 @@ struct common_sampler * 
common_sampler_init(const struct llama_model * model, co } } + // TODO: temporary until all samplers have llama_sampler_backend_ API [LLAMA_SAMPLER_BACKEND] + if (params.backend_sampling) { + params.filter_disabled(); + } + auto * result = new common_sampler { /* .params = */ params, /* .grmr = */ grmr, /* .chain = */ llama_sampler_chain_init(lparams), + /* .chain_backend = */ llama_sampler_chain_init(lparams), /* .prev = */ ring_buffer(std::max(32, params.n_prev)), /* .cur = */ {}, /* .cur_p = */ {}, }; - struct active_samplers active_samplers = get_active_samplers(params); + size_t idx_smpl = 0; - // Build CPU chain - if (!params.backend_sampling || !has_logit_bias(params)) { - llama_sampler_chain_add(result->chain, - llama_sampler_init_logit_bias( - llama_vocab_n_tokens(vocab), - params.logit_bias.size(), - params.logit_bias.data())); + bool is_backend = true; + + is_backend = is_backend && params.backend_sampling; + is_backend = is_backend && (params.samplers.size() == 0 || common_sampler_type_has_backend_support(params.samplers[idx_smpl])); + + if (params.has_logit_bias()) { + if (is_backend) { + llama_sampler_chain_add(result->chain_backend, + llama_sampler_backend_init_logit_bias( + llama_vocab_n_tokens(vocab), + params.logit_bias.size(), + params.logit_bias.data())); + } else { + llama_sampler_chain_add(result->chain, + llama_sampler_init_logit_bias( + llama_vocab_n_tokens(vocab), + params.logit_bias.size(), + params.logit_bias.data())); + } } if (params.mirostat == 0) { + // backend samplers are added first + while (is_backend && idx_smpl < params.samplers.size()) { + const auto & cnstr = params.samplers[idx_smpl++]; + + if (!common_sampler_type_has_backend_support(cnstr)) { + is_backend = false; + --idx_smpl; + break; + } + + switch (cnstr) { + case COMMON_SAMPLER_TYPE_TOP_K: + llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_top_k(params.top_k)); + break; + case COMMON_SAMPLER_TYPE_TEMPERATURE: + llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_temp(params.temp)); + break; + case COMMON_SAMPLER_TYPE_MIN_P: + llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_min_p(params.min_p)); + break; + default: + GGML_ASSERT(false && "unsupported backend sampler"); + } + } + // Add remaining CPU samplers - for (const auto & cnstr : active_samplers.cpu_samplers) { + while (idx_smpl < params.samplers.size()) { + const auto & cnstr = params.samplers[idx_smpl++]; + switch (cnstr) { case COMMON_SAMPLER_TYPE_DRY: { @@ -424,7 +435,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co } } - if (!active_samplers.cpu_samplers.empty()) { + if (is_backend) { + llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_dist(params.seed)); + } else { llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); } } else if (params.mirostat == 1) { @@ -440,59 +453,11 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co return result; } - -static struct backend_chain_data backend_samplers_init(const struct llama_model * model, const struct common_params_sampling & params, - struct active_samplers active_samplers) { - if (active_samplers.backend_samplers.empty()) { - return { nullptr, 0 }; - } - - const llama_vocab * vocab = llama_model_get_vocab(model); - - llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); - chain_params.no_perf = params.no_perf; - struct llama_sampler * chain = 
llama_sampler_chain_init(chain_params); - - // Add logit_bias to backend chain if present - if (has_logit_bias(params)) { - llama_sampler_chain_add(chain, llama_sampler_backend_init_logit_bias( - llama_vocab_n_tokens(vocab), - params.logit_bias.size(), - params.logit_bias.data())); - } - - for (const auto & sampler_type : active_samplers.backend_samplers) { - switch (sampler_type) { - case COMMON_SAMPLER_TYPE_TOP_K: - llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k)); - break; - case COMMON_SAMPLER_TYPE_TEMPERATURE: - llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp)); - break; - case COMMON_SAMPLER_TYPE_MIN_P: - llama_sampler_chain_add(chain, llama_sampler_backend_init_min_p(params.min_p)); - break; - default: - GGML_ASSERT(false && "unsupported backend sampler"); - } - } - - if (active_samplers.cpu_samplers.empty()) { - llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(params.seed)); - } - - return { chain, active_samplers.backend_samplers.size() + has_logit_bias(params) }; -} - -struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) { - struct active_samplers active_samplers = get_active_samplers(params); - return backend_samplers_init(model, params, active_samplers).chain; -} - void common_sampler_free(struct common_sampler * gsmpl) { if (gsmpl) { llama_sampler_free(gsmpl->grmr); llama_sampler_free(gsmpl->chain); + llama_sampler_free(gsmpl->chain_backend); delete gsmpl; } @@ -519,6 +484,7 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) { /* .params = */ gsmpl->params, /* .grmr = */ llama_sampler_clone(gsmpl->grmr), /* .chain = */ llama_sampler_clone(gsmpl->chain), + /* .chain_backend = */ llama_sampler_clone(gsmpl->chain_backend), /* .prev = */ gsmpl->prev, /* .cur = */ gsmpl->cur, /* .cur_p = */ gsmpl->cur_p, @@ -570,6 +536,10 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam } } +struct llama_sampler * common_sampler_chain_backend(const struct common_sampler * gsmpl) { + return gsmpl->chain_backend; +} + llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { // Check if a backend sampler has already sampled a token in which case we // return that token id directly. @@ -707,7 +677,12 @@ llama_token common_sampler_last(const struct common_sampler * gsmpl) { } std::string common_sampler_print(const struct common_sampler * gsmpl) { - std::string result = "logits "; + std::string result = llama_sampler_chain_n(gsmpl->chain_backend) > 0 ? 
"*logits " : "logits "; + + for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain_backend); i++) { + const auto * smpl = llama_sampler_chain_get(gsmpl->chain_backend, i); + result += std::string("-> *") + llama_sampler_name(smpl) + " "; + } for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) { const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); diff --git a/common/sampling.h b/common/sampling.h index 0ec164de05..06f27923a0 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -36,14 +36,8 @@ struct common_sampler; // llama_sampler API overloads -struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params); - -// Create a backend sampler chain from common sampling parameters -// Returns a llama_sampler chain configured with backend samplers based on the parameters -// This chain can be used per-sequence for backend-based sampling -// Note: Only samplers that have backend equivalents will be added to the chain -// The returned sampler should be freed with llama_sampler_free() -struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params); +// TODO: params should become const again [LLAMA_SAMPLER_BACKEND] +struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params); void common_sampler_free(struct common_sampler * gsmpl); @@ -55,6 +49,8 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl); // arguments can be nullptr to skip printing void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl); +struct llama_sampler * common_sampler_chain_backend(const struct common_sampler * gsmpl); + // extended sampling implementation: // // - set logits @@ -114,3 +110,9 @@ std::vector common_sampler_types_from_chars(const std: llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab, const char * grammar_kind, const char * grammar_data); + +struct common_sampler_deleter { + void operator()(common_sampler * s) { common_sampler_free(s); } +}; + +typedef std::unique_ptr common_sampler_ptr; diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index e9d1fc95c2..e23a3bab21 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -65,29 +65,34 @@ int main(int argc, char ** argv) { ctx_params.n_ctx = n_kv_req; ctx_params.n_batch = std::max(n_predict, n_parallel); - std::vector sampler_configs(n_parallel); - if (params.sampling.backend_sampling) { - for (int32_t i = 0; i < n_parallel; ++i) { - llama_sampler * backend_sampler = common_sampler_backend_init(model, params.sampling); - if (backend_sampler) { - sampler_configs[i] = { i, backend_sampler }; - } - } - ctx_params.samplers = sampler_configs.data(); - ctx_params.n_samplers = n_parallel; - } - - llama_context * ctx = llama_init_from_model(model, ctx_params); - auto sparams = llama_sampler_chain_default_params(); sparams.no_perf = false; - llama_sampler * smpl = llama_sampler_chain_init(sparams); + std::vector sampler_configs; - llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k)); - llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep)); - llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp)); - llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed)); + for (int32_t i = 0; i < n_parallel; ++i) { + llama_sampler * smpl = 
llama_sampler_chain_init(sparams); + + if (params.sampling.backend_sampling) { + llama_sampler_chain_add(smpl, llama_sampler_backend_init_top_k(params.sampling.top_k)); + llama_sampler_chain_add(smpl, llama_sampler_backend_init_temp (params.sampling.temp)); + llama_sampler_chain_add(smpl, llama_sampler_backend_init_dist (params.sampling.seed)); + } else { + llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k)); + llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep)); + llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp)); + llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed)); + } + + sampler_configs.push_back({ i, smpl }); + } + + if (params.sampling.backend_sampling) { + ctx_params.samplers = sampler_configs.data(); + ctx_params.n_samplers = sampler_configs.size(); + } + + llama_context * ctx = llama_init_from_model(model, ctx_params); if (ctx == NULL) { LOG_ERR("%s: error: failed to create the llama_context\n" , __func__); @@ -186,7 +191,7 @@ int main(int argc, char ** argv) { continue; } - const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]); + const llama_token new_token_id = llama_sampler_sample(sampler_configs[i].sampler, ctx, i_batch[i]); // is it an end of generation? -> mark the stream as finished if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) { @@ -242,14 +247,17 @@ int main(int argc, char ** argv) { __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); LOG("\n"); - llama_perf_sampler_print(smpl); + llama_perf_sampler_print(sampler_configs[0].sampler); llama_perf_context_print(ctx); fprintf(stderr, "\n"); llama_batch_free(batch); - llama_sampler_free(smpl); + for (auto & sampler_config : sampler_configs) { + llama_sampler_free(sampler_config.sampler); + } + llama_free(ctx); llama_model_free(model); diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 9e3ab5905b..ccf5bb6f55 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -130,10 +130,10 @@ int main(int argc, char ** argv) { llama_numa_init(params.numa); // load the model - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); if (model == NULL) { LOG_ERR("%s: unable to load model\n", __func__); diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 80c693ce61..408338f1af 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -202,10 +202,10 @@ int main(int argc, char ** argv) { params.warmup = false; // init - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); if (model == nullptr || ctx == nullptr) { LOG_ERR("%s : failed to init\n", __func__); diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 1e26d8221b..f54cfdd77f 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -55,10 +55,10 @@ int 
main(int argc, char ** argv) { llama_numa_init(params.numa); // load the target model - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); auto * mem = llama_get_memory(ctx); diff --git a/examples/lookup/lookup-create.cpp b/examples/lookup/lookup-create.cpp index 3da45ed9e0..bb94a8fe06 100644 --- a/examples/lookup/lookup-create.cpp +++ b/examples/lookup/lookup-create.cpp @@ -18,16 +18,16 @@ int main(int argc, char ** argv){ llama_numa_init(params.numa); // load the model - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model_ptr & model = llama_init.model; - llama_context_ptr & ctx = llama_init.context; + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); GGML_ASSERT(model != nullptr); // tokenize the prompt std::vector inp; - inp = common_tokenize(ctx.get(), params.prompt, true, true); + inp = common_tokenize(ctx, params.prompt, true, true); fprintf(stderr, "%s: tokenization done\n", __func__); common_ngram_cache ngram_cache; diff --git a/examples/lookup/lookup-stats.cpp b/examples/lookup/lookup-stats.cpp index fcb289abe0..135f6fcab9 100644 --- a/examples/lookup/lookup-stats.cpp +++ b/examples/lookup/lookup-stats.cpp @@ -28,13 +28,13 @@ int main(int argc, char ** argv){ llama_numa_init(params.numa); // load the model - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_context_ptr & ctx = llama_init.context; + llama_context * ctx = llama_init->context(); // tokenize the prompt std::vector inp; - inp = common_tokenize(ctx.get(), params.prompt, true, true); + inp = common_tokenize(ctx, params.prompt, true, true); common_ngram_cache ngram_cache_context; common_ngram_cache ngram_cache_dynamic; @@ -65,7 +65,7 @@ int main(int argc, char ** argv){ } const int n_input = inp.size(); - const int n_ctx = llama_n_ctx(ctx.get()); + const int n_ctx = llama_n_ctx(ctx); int n_drafted = 0; int n_accept = 0; diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 2bfa26b55f..27f159940a 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -29,10 +29,10 @@ int main(int argc, char ** argv){ llama_numa_init(params.numa); // load the model - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); const llama_vocab * vocab = llama_model_get_vocab(model); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index e48f48fc32..c92173ae29 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -192,10 +192,10 @@ int main(int argc, char ** argv) { llama_numa_init(params.numa); // load the target model - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); auto * mem = llama_get_memory(ctx); diff --git a/examples/retrieval/retrieval.cpp 
b/examples/retrieval/retrieval.cpp index 042e12c2bf..2c2143ad10 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -149,10 +149,10 @@ int main(int argc, char ** argv) { llama_numa_init(params.numa); // load the model - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); if (model == NULL) { LOG_ERR("%s: unable to load model\n", __func__); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 1065ec6bb0..615929c1e0 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -34,10 +34,10 @@ int main(int argc, char ** argv) { std::string result2; // init - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); if (model == nullptr || ctx == nullptr) { fprintf(stderr, "%s : failed to init\n", __func__); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index a8e53f28eb..a5bd5c9125 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -40,10 +40,10 @@ int main(int argc, char ** argv) { llama_context * ctx_dft = NULL; // load the target model - common_init_result llama_init_tgt = common_init_from_params(params); + auto llama_init_tgt = common_init_from_params(params); - model_tgt = llama_init_tgt.model.get(); - ctx_tgt = llama_init_tgt.context.get(); + model_tgt = llama_init_tgt->model(); + ctx_tgt = llama_init_tgt->context(); const llama_vocab * vocab = llama_model_get_vocab(model_tgt); @@ -61,10 +61,10 @@ int main(int argc, char ** argv) { params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; params.tensor_buft_overrides = params.speculative.tensor_buft_overrides; - common_init_result llama_init_dft = common_init_from_params(params); + auto llama_init_dft = common_init_from_params(params); - //model_dft = llama_init_dft.model.get(); - ctx_dft = llama_init_dft.context.get(); + //model_dft = llama_init_dft->model(); + ctx_dft = llama_init_dft->context(); if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) { LOG_INF("the draft model '%s' is not compatible with the target model '%s'. 
tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str()); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 5f5ac5eb64..89d3249431 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -71,10 +71,10 @@ int main(int argc, char ** argv) { llama_context * ctx_dft = NULL; // load the target model - common_init_result llama_init_tgt = common_init_from_params(params); + auto llama_init_tgt = common_init_from_params(params); - model_tgt = llama_init_tgt.model.get(); - ctx_tgt = llama_init_tgt.context.get(); + model_tgt = llama_init_tgt->model(); + ctx_tgt = llama_init_tgt->context(); // load the draft model params.devices = params.speculative.devices; @@ -87,10 +87,10 @@ int main(int argc, char ** argv) { params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; params.tensor_buft_overrides = params.speculative.tensor_buft_overrides; - common_init_result llama_init_dft = common_init_from_params(params); + auto llama_init_dft = common_init_from_params(params); - model_dft = llama_init_dft.model.get(); - ctx_dft = llama_init_dft.context.get(); + model_dft = llama_init_dft->model(); + ctx_dft = llama_init_dft->context(); const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); diff --git a/examples/training/finetune.cpp b/examples/training/finetune.cpp index 416d8d8f6c..c82de8d35d 100644 --- a/examples/training/finetune.cpp +++ b/examples/training/finetune.cpp @@ -39,9 +39,10 @@ int main(int argc, char ** argv) { llama_backend_init(); llama_numa_init(params.numa); // load the model and apply lora adapter, if any - common_init_result llama_init = common_init_from_params(params); - llama_model_ptr & model = llama_init.model; - llama_context_ptr & ctx = llama_init.context; + auto llama_init = common_init_from_params(params); + + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); if (model == NULL) { LOG_ERR("%s: unable to load model\n", __func__); @@ -54,8 +55,8 @@ int main(int argc, char ** argv) { LOG_INF("%s\n", common_params_get_system_info(params).c_str()); } - std::vector tokens = common_tokenize(ctx.get(), params.prompt, true); - ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get()) / 2); + std::vector tokens = common_tokenize(ctx, params.prompt, true); + ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx, tokens, llama_n_ctx(ctx) / 2); struct lr_opt & lr = params.lr; LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %d -period %.2g -val %.2g\n", @@ -70,7 +71,7 @@ int main(int argc, char ** argv) { /*get_opt_pars_ud =*/¶ms.lr, /*optimizer_type =*/params.optimizer, }; - llama_opt_init(ctx.get(), model.get(), lopt_params); + llama_opt_init(ctx, model, lopt_params); const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split); @@ -78,7 +79,7 @@ int main(int argc, char ** argv) { ggml_opt_result_t result_eval = ggml_opt_result_init(); for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) { - llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split, + llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split, ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar); fprintf(stderr, "\n"); @@ -88,7 +89,7 @@ int main(int argc, char ** argv) { 
ggml_opt_result_free(result_train); ggml_opt_result_free(result_eval); - llama_model_save_to_file(model.get(), params.out_file.c_str()); + llama_model_save_to_file(model, params.out_file.c_str()); llama_backend_free(); diff --git a/include/llama.h b/include/llama.h index 50a4cc7c13..080ac27f1f 100644 --- a/include/llama.h +++ b/include/llama.h @@ -376,7 +376,7 @@ extern "C" { // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix // ref: https://github.com/ggml-org/llama.cpp/pull/14363 - // backend sampler chain configuration + // backend sampler chain configuration (does not keep a reference, so make sure the caller keeps the samplers alive) struct llama_sampler_seq_config * samplers; size_t n_samplers; }; @@ -1209,10 +1209,6 @@ extern "C" { void (*init_ggml)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft); - - - // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph - //void (*apply_ggml) (struct llama_sampler * smpl, ...); }; struct llama_sampler { @@ -1231,18 +1227,17 @@ extern "C" { LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl); // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add) LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); - LLAMA_API void llama_sampler_init_ggml(struct llama_sampler * smpl, - ggml_backend_buffer_type_t buft); - LLAMA_API void llama_sampler_set_input_ggml(struct llama_sampler * smpl); - LLAMA_API void llama_sampler_apply_ggml( struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data); - LLAMA_API void llama_sampler_accept_ggml( struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct ggml_tensor * selected_token); + LLAMA_API void llama_sampler_init_ggml (struct llama_sampler * smpl, ggml_backend_buffer_type_t buft); + LLAMA_API void llama_sampler_set_input_ggml(struct llama_sampler * smpl); + LLAMA_API void llama_sampler_apply_ggml (struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data); + LLAMA_API void llama_sampler_accept_ggml (struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct ggml_tensor * selected_token); // llama_sampler_chain // a type of llama_sampler that can chain multiple samplers one after another diff --git a/src/llama-context.cpp b/src/llama-context.cpp index b8c5accff8..3e15789f28 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -66,7 +66,15 @@ llama_context::llama_context( for (size_t i = 0; i < params.n_samplers; ++i) { const auto & config = params.samplers[i]; + + const int n_samplers = llama_sampler_chain_n(config.sampler); + if (n_samplers <= 0) { + continue; + } + sampling.samplers[config.seq_id] = config.sampler; + + LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers); } } @@ -438,8 +446,8 @@ llama_context::llama_context( // Initialize the full vocabulary token ids for backend samplers. 
{ - const llama_vocab * vocab = llama_model_get_vocab(&model); - const int n_vocab = llama_vocab_n_tokens(vocab); + const int n_vocab = model.vocab.n_tokens(); + sampling.token_ids_full_vocab.resize(n_vocab); for (int i = 0; i < n_vocab; ++i) { sampling.token_ids_full_vocab[i] = i; @@ -449,10 +457,6 @@ llama_context::llama_context( llama_context::~llama_context() { ggml_opt_free(opt_ctx); - // TODO: perhaps use a smart pointer for samplers - for (auto const& [seq_id, sampler] : sampling.samplers) { - llama_sampler_free(sampler); - } } void llama_context::synchronize() { @@ -910,31 +914,10 @@ void llama_context::set_warmup(bool value) { void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) { LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); - auto it = sampling.samplers.find(seq_id); - if (it != sampling.samplers.end()) { - // If the sampler to be set is the same that is already set, do nothing. - if (it->second == sampler) { - return; - } - - llama_sampler_free(it->second); - - // If sampler is nullptr, we remove the samppler chain for this seq_id. - // chain for this seq_id. - if (sampler == nullptr) { - sampling.samplers.erase(it); - return; - } - - // Otherwise, we replace the existing sampler with the new one. - it->second = sampler; - return; - } - - // If there is no sampler for this seq_id and the caller provides a non-null - // sampler, we set it. - if (sampler != nullptr) { + if (sampler != nullptr && llama_sampler_chain_n(sampler) > 0) { sampling.samplers[seq_id] = sampler; + } else { + sampling.samplers.erase(seq_id); } } @@ -1700,8 +1683,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba sampling.sampled_size = n_outputs_max; sampling.candidates_size = n_vocab*n_outputs_max; - backend_float_count = sampling.logits_size + sampling.probs_size; - backend_token_count = sampling.sampled_size + sampling.candidates_size; + backend_float_count = sampling.logits_size + sampling.probs_size; + backend_token_count = sampling.sampled_size + sampling.candidates_size; } if (output_ids.empty()) { diff --git a/src/llama-context.h b/src/llama-context.h index 1dcd3bf419..2940d337a8 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -258,7 +258,7 @@ private: float * logits = nullptr; struct sampling_info { - std::unordered_map samplers; + std::unordered_map samplers; float * logits = nullptr; size_t logits_size = 0; diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 5c3214f029..4d1760a629 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -439,10 +439,10 @@ void llama_sampler_free(struct llama_sampler * smpl) { } llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { - const llama_token sampled_token = llama_get_backend_sampled_token_ith(ctx, idx); - const float * sampled_probs = llama_get_backend_sampled_probs_ith(ctx, idx); - const float * sampled_logits = llama_get_backend_sampled_logits_ith(ctx, idx); - const llama_token * sampled_ids = llama_get_backend_sampled_candidates_ith(ctx, idx); + const llama_token sampled_token = llama_get_backend_sampled_token_ith (ctx, idx); + const float * sampled_probs = llama_get_backend_sampled_probs_ith (ctx, idx); + const float * sampled_logits = llama_get_backend_sampled_logits_ith (ctx, idx); + const llama_token * sampled_ids = llama_get_backend_sampled_candidates_ith(ctx, idx); // If a backend sampler has already sampled a token, return it. 
if (sampled_token != LLAMA_TOKEN_NULL) { diff --git a/tools/cvector-generator/cvector-generator.cpp b/tools/cvector-generator/cvector-generator.cpp index d2d97e05ce..3ba7c52950 100644 --- a/tools/cvector-generator/cvector-generator.cpp +++ b/tools/cvector-generator/cvector-generator.cpp @@ -419,10 +419,10 @@ int main(int argc, char ** argv) { llama_numa_init(params.numa); // load the model to get hparams - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); // int n_ctx = llama_n_ctx(ctx); int n_layers = llama_model_n_layer(model); diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp index f28a036dee..669de55ddb 100644 --- a/tools/imatrix/imatrix.cpp +++ b/tools/imatrix/imatrix.cpp @@ -1265,10 +1265,10 @@ int main(int argc, char ** argv) { params.warmup = false; // init - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); if (model == nullptr || ctx == nullptr) { LOG_ERR("%s : failed to init\n", __func__); diff --git a/tools/main/main.cpp b/tools/main/main.cpp index 263387f417..6f64708dcd 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -138,9 +138,11 @@ int main(int argc, char ** argv) { // load the model and apply lora adapter, if any LOG_INF("%s: load the model and apply lora adapter, if any\n", __func__); - common_init_result llama_init = common_init_from_params(params); - ctx = llama_init.context.get(); - model = llama_init.model.get(); // Update pointer (now managed by llama_init) + auto llama_init = common_init_from_params(params); + + ctx = llama_init->context(); + model = llama_init->model(); + smpl = llama_init->sampler(0); if (ctx == NULL) { LOG_ERR("%s: error: unable to create context\n", __func__); @@ -470,12 +472,6 @@ int main(int argc, char ** argv) { } } - smpl = common_sampler_init(model, sparams); - if (!smpl) { - LOG_ERR("%s: failed to initialize sampling subsystem\n", __func__); - return 1; - } - LOG_INF("sampler seed: %u\n", common_sampler_get_seed(smpl)); LOG_INF("sampler params: \n%s\n", sparams.print().c_str()); LOG_INF("sampler chain: %s\n", common_sampler_print(smpl).c_str()); @@ -989,8 +985,6 @@ int main(int argc, char ** argv) { LOG("\n\n"); common_perf_print(ctx, smpl); - common_sampler_free(smpl); - llama_backend_free(); ggml_threadpool_free_fn(threadpool); diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index 6679de309b..aaec65d8ff 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -65,7 +65,7 @@ static void sigint_handler(int signo) { struct mtmd_cli_context { mtmd::context_ptr ctx_vision; - common_init_result llama_init; + common_init_result_ptr llama_init; llama_model * model; llama_context * lctx; @@ -89,8 +89,8 @@ struct mtmd_cli_context { llama_pos n_past = 0; mtmd_cli_context(common_params & params) : llama_init(common_init_from_params(params)) { - model = llama_init.model.get(); - lctx = llama_init.context.get(); + model = llama_init->model(); + lctx = llama_init->context(); vocab = llama_model_get_vocab(model); smpl = common_sampler_init(model, params.sampling); n_threads = params.cpuparams.n_threads; diff --git 
a/tools/perplexity/perplexity.cpp b/tools/perplexity/perplexity.cpp index caf080e8d1..1ead9c871e 100644 --- a/tools/perplexity/perplexity.cpp +++ b/tools/perplexity/perplexity.cpp @@ -2024,10 +2024,10 @@ int main(int argc, char ** argv) { llama_numa_init(params.numa); // load the model and apply lora adapter, if any - common_init_result llama_init = common_init_from_params(params); + auto llama_init = common_init_from_params(params); - llama_model * model = llama_init.model.get(); - llama_context * ctx = llama_init.context.get(); + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); if (model == NULL) { LOG_ERR("%s: unable to load model\n", __func__); diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index d083777fa7..c38131b587 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -207,8 +207,8 @@ task_params server_task::params_from_json_cmpl( params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); - const bool request_backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling); - params.sampling.backend_sampling = defaults.sampling.backend_sampling && request_backend_sampling; + const bool request_backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling); + params.sampling.backend_sampling = defaults.sampling.backend_sampling && request_backend_sampling; params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 48a7cf2c82..0517d21518 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -151,8 +151,7 @@ struct server_slot { // sampling json json_schema; - struct common_sampler * smpl = nullptr; - llama_sampler * backend_sampler = nullptr; + common_sampler_ptr smpl; llama_token sampled; @@ -198,13 +197,6 @@ struct server_slot { n_draft_total = 0; n_draft_accepted = 0; - if (backend_sampler != nullptr) { - if (ctx != nullptr) { - llama_set_backend_sampler(ctx, id, nullptr); - } - backend_sampler = nullptr; - } - task.reset(); task_prev.reset(); @@ -481,8 +473,8 @@ struct server_context { common_params params_base; // note: keep these alive - they determine the lifetime of the model, context, etc. 
- common_init_result llama_init; - common_init_result llama_init_dft; + common_init_result_ptr llama_init; + common_init_result_ptr llama_init_dft; llama_model * model = nullptr; llama_context * ctx = nullptr; @@ -526,16 +518,6 @@ struct server_context { // Clear any sampling context for (server_slot & slot : slots) { - common_sampler_free(slot.smpl); - slot.smpl = nullptr; - - if (slot.backend_sampler != nullptr) { - if (ctx != nullptr) { - llama_set_backend_sampler(ctx, slot.id, nullptr); - } - slot.backend_sampler = nullptr; - } - llama_free(slot.ctx_dft); slot.ctx_dft = nullptr; @@ -556,8 +538,8 @@ struct server_context { llama_init = common_init_from_params(params_base); - model = llama_init.model.get(); - ctx = llama_init.context.get(); + model = llama_init->model(); + ctx = llama_init->context(); if (model == nullptr) { SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); @@ -589,25 +571,25 @@ struct server_context { llama_init_dft = common_init_from_params(params_dft); - model_dft = llama_init_dft.model.get(); + model_dft = llama_init_dft->model(); if (model_dft == nullptr) { SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str()); return false; } - vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get()); + vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft->context()); if (!vocab_dft_compatible) { SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); } - const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); + const int n_ctx_dft = llama_n_ctx(llama_init_dft->context()); cparams_dft = common_context_params_to_llama(params_dft); cparams_dft.n_batch = n_ctx_dft; // the context is not needed - we will create one for each slot - llama_init_dft.context.reset(); + llama_init_dft->free_context(); } chat_templates = common_chat_templates_init(model, params_base.chat_template); @@ -1001,23 +983,17 @@ struct server_context { // initialize samplers { - if (slot.smpl != nullptr) { - common_sampler_free(slot.smpl); - } - - slot.smpl = common_sampler_init(model, task.params.sampling); + slot.smpl.reset(common_sampler_init(model, task.params.sampling)); if (slot.smpl == nullptr) { // for now, the only error that may happen here is invalid grammar send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); return false; } - SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str()); - } + SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str()); - if (!configure_slot_backend_sampler(slot, task.params.sampling)) { - send_error(task, "Failed to configure backend samplers", ERROR_TYPE_SERVER); - return false; + llama_sampler * backend_chain = common_sampler_chain_backend(slot.smpl.get()); + llama_set_backend_sampler(ctx, slot.id, backend_chain); } // initialize draft batch @@ -1037,39 +1013,6 @@ struct server_context { return true; } - bool configure_slot_backend_sampler(server_slot & slot, const common_params_sampling & sampling) { - if (!sampling.backend_sampling) { - if (slot.backend_sampler != nullptr) { - llama_set_backend_sampler(ctx, slot.id, nullptr); - slot.backend_sampler = nullptr; - } - return true; - } - - llama_sampler * backend_chain = common_sampler_backend_init(model, sampling); - // The sampler types configured with --samplers might not 
be supported - // by backend samplers in which case we disable backend sampling and - // fallback to CPU only sampling. - if (backend_chain == nullptr) { - if (slot.backend_sampler != nullptr) { - llama_set_backend_sampler(ctx, slot.id, nullptr); - slot.backend_sampler = nullptr; - } - SLT_INF(slot, "%s", "no backend samplers configured (sampler chain doesn't start with backend-supported samplers)\n"); - return true; - } - - if (slot.backend_sampler != nullptr) { - llama_set_backend_sampler(ctx, slot.id, nullptr); - slot.backend_sampler = nullptr; - } - - slot.backend_sampler = backend_chain; - llama_set_backend_sampler(ctx, slot.id, backend_chain); - SLT_INF(slot, "%s", "configured backend samplers\n"); - return true; - } - bool process_token(completion_token_output & result, server_slot & slot) { // remember which tokens were sampled - used for repetition penalties during sampling const std::string token_str = result.text_to_send; @@ -1206,7 +1149,7 @@ struct server_context { size_t n_vocab = llama_vocab_n_tokens(vocab); if (post_sampling) { - const auto * cur_p = common_sampler_get_candidates(slot.smpl, true); + const auto * cur_p = common_sampler_get_candidates(slot.smpl.get(), true); const size_t max_probs = cur_p->size; // set probability for sampled token @@ -2185,13 +2128,13 @@ struct server_context { GGML_ASSERT(batch.n_tokens > 0); - common_sampler_reset(slot.smpl); + common_sampler_reset(slot.smpl.get()); // Process all prompt tokens through sampler system for (int i = 0; i < slot.task->n_tokens(); ++i) { llama_token id = input_tokens[i]; if (id != LLAMA_TOKEN_NULL) { - common_sampler_accept(slot.smpl, id, false); + common_sampler_accept(slot.smpl.get(), id, false); } } @@ -2381,11 +2324,11 @@ struct server_context { const int tok_idx = slot.i_batch - i; - llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); + llama_token id = common_sampler_sample(slot.smpl.get(), ctx, tok_idx); slot.i_batch = -1; - common_sampler_accept(slot.smpl, id, true); + common_sampler_accept(slot.smpl.get(), id, true); slot.n_decoded += 1; @@ -2488,7 +2431,7 @@ struct server_context { llama_decode(ctx, slot.batch_spec); // the accepted tokens from the speculation - const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); + const auto ids = common_sampler_sample_and_accept_n(slot.smpl.get(), ctx, draft); slot.n_decoded += ids.size(); diff --git a/tools/tts/tts.cpp b/tools/tts/tts.cpp index eaf56591d9..8c39fce8ba 100644 --- a/tools/tts/tts.cpp +++ b/tools/tts/tts.cpp @@ -568,10 +568,10 @@ int main(int argc, char ** argv) { llama_context * ctx_ttc = NULL; llama_context * ctx_cts = NULL; - common_init_result llama_init_ttc = common_init_from_params(params); + auto llama_init_ttc = common_init_from_params(params); - model_ttc = llama_init_ttc.model.get(); - ctx_ttc = llama_init_ttc.context.get(); + model_ttc = llama_init_ttc->model(); + ctx_ttc = llama_init_ttc->context(); if (model_ttc == nullptr || ctx_ttc == nullptr) { return ENOENT; @@ -583,10 +583,10 @@ int main(int argc, char ** argv) { params.embedding = true; params.n_ubatch = params.n_batch; - common_init_result llama_init_cts = common_init_from_params(params); + auto llama_init_cts = common_init_from_params(params); - model_cts = llama_init_cts.model.get(); - ctx_cts = llama_init_cts.context.get(); + model_cts = llama_init_cts->model(); + ctx_cts = llama_init_cts->context(); if (model_cts == nullptr || ctx_cts == nullptr) { return ENOENT;