diff --git a/common/arg.cpp b/common/arg.cpp index 430ab45dfe..ab3386b1df 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1501,6 +1501,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.sampling.grammar = json_schema_to_grammar(json::parse(schema)); } ).set_sparam()); + add_opt(common_arg( + {"--backend-sampling"}, + "enable backend sampling (default: disabled)", + [](common_params & params) { + params.sampling.backend_sampling = true; + } + ).set_sparam()); + add_opt(common_arg( + {"--backend-dist"}, + "perform final (distribution) sampling on backend (default: disabled)", + [](common_params & params) { + params.sampling.backend_dist = true; + params.sampling.backend_sampling = true; + } + ).set_sparam()); add_opt(common_arg( {"--pooling"}, "{none,mean,cls,last,rank}", "pooling type for embeddings, use model default if unspecified", diff --git a/common/common.cpp b/common/common.cpp index 4dc95dcba2..c31619ac36 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -8,6 +8,7 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "sampling.h" #include #include @@ -956,6 +957,8 @@ struct common_init_result common_init_from_params(common_params & params) { const llama_vocab * vocab = llama_model_get_vocab(model); auto cparams = common_context_params_to_llama(params); + cparams.samplers = params.backend_samplers; + cparams.n_samplers = params.n_backend_samplers; llama_context * lctx = llama_init_from_model(model, cparams); if (lctx == NULL) { diff --git a/common/common.h b/common/common.h index f42c083faa..b320d891f5 100644 --- a/common/common.h +++ b/common/common.h @@ -188,6 +188,10 @@ struct common_params_sampling { std::vector logit_bias; // logit biases to apply std::vector logit_bias_eog; // pre-calculated logit biases for EOG tokens + // Backend sampling flags + bool backend_sampling = false; // enable backend sampling + bool backend_dist = false; // backend performs final sampling (dist) + // print the parameters into a string std::string print() const; }; @@ -512,6 +516,9 @@ struct common_params { bool has_speculative() const { return !speculative.model.path.empty() || !speculative.model.hf_repo.empty(); } + + struct llama_sampler_seq_config * backend_samplers = NULL; + size_t n_backend_samplers = 0; }; // call once at the start of a program if it uses libcommon diff --git a/common/llguidance.cpp b/common/llguidance.cpp index adce620e4d..27d15516e9 100644 --- a/common/llguidance.cpp +++ b/common/llguidance.cpp @@ -106,12 +106,16 @@ static void llama_sampler_llg_free(llama_sampler * smpl) { } static llama_sampler_i llama_sampler_llg_i = { - /* .name = */ llama_sampler_llg_name, - /* .accept = */ llama_sampler_llg_accept_impl, - /* .apply = */ llama_sampler_llg_apply, - /* .reset = */ llama_sampler_llg_reset, - /* .clone = */ llama_sampler_llg_clone, - /* .free = */ llama_sampler_llg_free, + /* .name = */ llama_sampler_llg_name, + /* .accept = */ llama_sampler_llg_accept_impl, + /* .apply = */ llama_sampler_llg_apply, + /* .reset = */ llama_sampler_llg_reset, + /* .clone = */ llama_sampler_llg_clone, + /* .free = */ llama_sampler_llg_free, + /* .apply_ggml = */ NULL, + /* .accept_ggml = */ NULL, + /* .set_input_ggml = */ NULL, + /* .set_backend_context = */ NULL, }; static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len, diff --git a/common/sampling.cpp b/common/sampling.cpp index c69d525b5b..1fc5c7ce0a 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ 
-113,23 +113,61 @@ struct common_sampler { llama_token_data_array cur_p; void set_logits(struct llama_context * ctx, int idx) { - const auto * logits = llama_get_logits_ith(ctx, idx); + const float * sampled_probs = llama_get_backend_sampled_probs_ith(ctx, idx); + const float * sampled_logits = llama_get_backend_sampled_logits_ith(ctx, idx); + const llama_token * sampled_ids = llama_get_backend_sampled_token_ids_ith(ctx, idx); const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); const int n_vocab = llama_vocab_n_tokens(vocab); - cur.resize(n_vocab); + // Use the member variable instead of allocating locally + cur.clear(); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + if (sampled_probs) { + const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx); + cur.reserve(sampled_probs_count); + // The backend sampler has filtered the probabilities so we need to use the sampled ids. + if (sampled_ids != nullptr) { + for (uint32_t i = 0; i < sampled_probs_count; ++i) { + cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]}); + } + } else { + for (llama_token token_id = 0; token_id < (int) sampled_probs_count; token_id++) { + cur.emplace_back(llama_token_data{token_id, 0.0f, sampled_probs[token_id]}); + } + } + } else if (sampled_logits) { + const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx); + cur.reserve(sampled_logits_count); + // The backend sampler has filtered the logits so we need to use the sampled ids. + if (sampled_ids != nullptr) { + for (llama_token i = 0; i < (int)sampled_logits_count; i++) { + cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f}); + } + } else { + for (llama_token token_id = 0; token_id < (int)sampled_logits_count; token_id++) { + cur.emplace_back(llama_token_data{token_id, sampled_logits[token_id], 0.0f}); + } + } + } else { + const auto * logits = llama_get_logits_ith(ctx, idx); + GGML_ASSERT(logits != nullptr); + cur.reserve(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } } cur_p = { cur.data(), cur.size(), -1, false }; } }; +static bool sampler_enabled(const struct common_params_sampling & params, enum common_sampler_type type) { + return std::find(params.samplers.begin(), params.samplers.end(), type) != params.samplers.end(); +} + std::string common_params_sampling::print() const { char result[1024]; @@ -287,6 +325,43 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co return result; } +struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) { + const llama_vocab * vocab = llama_model_get_vocab(model); + + llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + chain_params.no_perf = params.no_perf; + + struct llama_sampler * chain = llama_sampler_chain_init(chain_params); + if (!params.backend_sampling) { + return chain; // return empty chain + } + + const bool enable_temp = params.temp > 0.0f && sampler_enabled(params, COMMON_SAMPLER_TYPE_TEMPERATURE); + const bool enable_top_k = params.top_k > 0 && sampler_enabled(params, COMMON_SAMPLER_TYPE_TOP_K); + const bool enable_dist = params.backend_dist; + + if (!params.logit_bias.empty()) { + llama_sampler_chain_add(chain, 
llama_sampler_backend_init_logit_bias( + llama_vocab_n_tokens(vocab), + params.logit_bias.size(), + params.logit_bias.data())); + } + + if (enable_temp) { + llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp)); + } + + if (enable_top_k) { + llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k)); + } + + if (enable_dist) { + llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(params.seed)); + } + + return chain; +} + void common_sampler_free(struct common_sampler * gsmpl) { if (gsmpl) { llama_sampler_free(gsmpl->grmr); @@ -337,6 +412,14 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam } llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { + // Check if a backend sampler has already sampled a token in which case we + // return that token id directly. + const llama_token backend_sampled_token = llama_get_backend_sampled_token_ith(ctx, idx); + if (backend_sampled_token != LLAMA_TOKEN_NULL) { + LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, backend_sampled_token); + return backend_sampled_token; + } + gsmpl->set_logits(ctx, idx); auto & grmr = gsmpl->grmr; diff --git a/common/sampling.h b/common/sampling.h index e198eecda3..0ec164de05 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -38,6 +38,13 @@ struct common_sampler; struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params); +// Create a backend sampler chain from common sampling parameters +// Returns a llama_sampler chain configured with backend samplers based on the parameters +// This chain can be used per-sequence for backend-based sampling +// Note: Only samplers that have backend equivalents will be added to the chain +// The returned sampler should be freed with llama_sampler_free() +struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params); + void common_sampler_free(struct common_sampler * gsmpl); // if accept_grammar is true, the token is accepted both by the sampling chain and the grammar diff --git a/include/llama.h b/include/llama.h index 8547226ff2..cbf23c7bcf 100644 --- a/include/llama.h +++ b/include/llama.h @@ -210,6 +210,13 @@ extern "C" { bool sorted; // note: do not assume the data is sorted - always check this flag } llama_token_data_array; + struct llama_sampler_ggml_data { + struct ggml_tensor * logits; + struct ggml_tensor * probs; + struct ggml_tensor * sampled_token; + struct ggml_tensor * filtered_ids; + }; + typedef bool (*llama_progress_callback)(float progress, void * user_data); // Input data for llama_encode/llama_decode @@ -300,6 +307,11 @@ extern "C" { bool no_host; // bypass host buffer allowing extra buffers to be used }; + struct llama_sampler_seq_config { + llama_seq_id seq_id; + struct llama_sampler * sampler; + }; + // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations // https://github.com/ggml-org/llama.cpp/pull/7544 struct llama_context_params { @@ -348,6 +360,10 @@ extern "C" { bool kv_unified; // use a unified buffer across the input sequences when computing the attention // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix // ref: https://github.com/ggml-org/llama.cpp/pull/14363 + + // backend 
sampler chain configuration
+        struct llama_sampler_seq_config * samplers;
+        size_t n_samplers;
     };

     // model quantization parameters
@@ -950,6 +966,29 @@ extern "C" {
     // otherwise: float[n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);

+    // Get the backend sampled token for the ith token.
+    // Returns LLAMA_TOKEN_NULL if no token was sampled.
+    LLAMA_API llama_token llama_get_backend_sampled_token_ith(struct llama_context * ctx, int32_t i);
+
+    // Get the backend sampled probabilities for the ith token
+    // The index matches llama_get_backend_sampled_token_ith().
+    // Returns NULL if no probabilities were generated.
+    LLAMA_API float * llama_get_backend_sampled_probs_ith(struct llama_context * ctx, int32_t i);
+
+    // Get the backend sampled logits for the ith token
+    // Returns NULL if no logits were sampled.
+    LLAMA_API float * llama_get_backend_sampled_logits_ith(struct llama_context * ctx, int32_t i);
+
+    // Get the backend sampled token ids associated with the sampled logits for the ith token
+    // Returns NULL if no logits were sampled.
+    LLAMA_API llama_token * llama_get_backend_sampled_token_ids_ith(struct llama_context * ctx, int32_t i);
+
+    // Get the number of backend sampled logits for the ith token.
+    LLAMA_API uint32_t llama_get_backend_sampled_logits_count_ith(struct llama_context * ctx, int32_t i);
+
+    // Get the number of backend sampled probabilities for the ith token.
+    LLAMA_API uint32_t llama_get_backend_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);
+
     //
     // Vocab
     //
@@ -1135,6 +1174,22 @@ extern "C" {
         struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
         void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL

+        void (*apply_ggml)( struct llama_sampler * smpl,
+                            struct ggml_context * ctx,
+                            struct ggml_cgraph * gf,
+                            struct llama_sampler_ggml_data * ggml_data);
+
+        void (*accept_ggml)( struct llama_sampler * smpl,
+                             struct ggml_context * ctx,
+                             struct ggml_cgraph * gf,
+                             struct ggml_tensor * selected_token);
+
+        void (*set_input_ggml)(struct llama_sampler * smpl);
+
+        void (*init_ggml)(struct llama_sampler * smpl,
+                          ggml_backend_buffer_type_t buft);
+
         // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
         //void (*apply_ggml) (struct llama_sampler * smpl, ...);
     };
@@ -1144,6 +1199,8 @@ extern "C" {
         llama_sampler_context_t ctx;
     };

+    LLAMA_API void llama_set_backend_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
+
     // mirror of llama_sampler_i:
     LLAMA_API struct llama_sampler * llama_sampler_init  (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
     LLAMA_API const char *           llama_sampler_name  (const struct llama_sampler * smpl);
@@ -1153,6 +1210,18 @@ extern "C" {
     LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl);
     // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
     LLAMA_API void                   llama_sampler_free  (      struct llama_sampler * smpl);
+    LLAMA_API void llama_sampler_init_ggml(struct llama_sampler * smpl,
+                                           ggml_backend_buffer_type_t buft);
+    LLAMA_API void llama_sampler_set_input_ggml(struct llama_sampler * smpl);
+    LLAMA_API void llama_sampler_apply_ggml( struct llama_sampler * smpl,
+                                             struct ggml_context * ctx,
+                                             struct ggml_cgraph * gf,
+                                             struct llama_sampler_ggml_data * ggml_data);
+
+    LLAMA_API void llama_sampler_accept_ggml( struct 
llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct ggml_tensor * selected_token); // llama_sampler_chain // a type of llama_sampler that can chain multiple samplers one after another @@ -1166,6 +1235,7 @@ extern "C" { // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed LLAMA_API struct llama_sampler * llama_sampler_chain_remove( struct llama_sampler * chain, int32_t i); + LLAMA_API uint64_t llama_sampler_chain_get_version(const struct llama_sampler * chain); // available samplers: @@ -1299,9 +1369,29 @@ extern "C" { // LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab); + // + // Backend samplers + // + + /// @details Greedy sampling on backend - always selects the token with the highest probability + LLAMA_API struct llama_sampler * llama_sampler_backend_init_greedy(void); + + /// @details Temperature scaling on backend - scales logits by 1/temperature + LLAMA_API struct llama_sampler * llama_sampler_backend_init_temp(float temp); + + /// @details Top-K filtering on backend - keeps only the k tokens with highest probabilities + LLAMA_API struct llama_sampler * llama_sampler_backend_init_top_k(int32_t k); + + /// @details Distribution sampling on backend - final sampling step that selects a token + LLAMA_API struct llama_sampler * llama_sampler_backend_init_dist(uint32_t seed); + // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_backend_init_logit_bias(int32_t n_vocab, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias); + /// @details Sample and accept a token from the idx-th output of the last evaluation // // Shorthand for: diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8ec95ee176..c17b890089 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -31,6 +31,7 @@ add_library(llama llama-model.cpp llama-quant.cpp llama-sampling.cpp + llama-backend-sampler.cpp llama-vocab.cpp unicode-data.cpp unicode.cpp diff --git a/src/llama-backend-sampler.cpp b/src/llama-backend-sampler.cpp new file mode 100644 index 0000000000..42c8d85aeb --- /dev/null +++ b/src/llama-backend-sampler.cpp @@ -0,0 +1,489 @@ +#include "llama.h" +#include "ggml.h" +#include +#include +#include +#include +#include + +static void llama_sampler_backend_greedy_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + GGML_UNUSED(gf); + GGML_UNUSED(smpl); + struct ggml_tensor * argmax_result = ggml_argmax(ctx, ggml_data->logits); + ggml_set_name(argmax_result, "argmax_result"); + ggml_data->sampled_token = argmax_result; +} + +static const char * llama_sampler_backend_greedy_sampler_name(const struct llama_sampler *) { + return "test-ggml"; +} + +static struct llama_sampler * llama_sampler_backend_greedy_clone(const struct llama_sampler * smpl) { + (void) smpl; + return llama_sampler_backend_init_greedy(); +} + +struct llama_sampler * llama_sampler_backend_init_greedy() { + static const llama_sampler_i iface = { + /*.name =*/ llama_sampler_backend_greedy_sampler_name, + /*.accept =*/ nullptr, + /*.apply =*/ nullptr, + /*.reset =*/ nullptr, + /*.clone =*/ llama_sampler_backend_greedy_clone, + /*.free =*/ nullptr, + /*.apply_ggml =*/ llama_sampler_backend_greedy_apply_ggml, + /*.accept_ggml =*/ nullptr, + 
/*.set_input_ggml =*/ nullptr, + /*.init_ggml =*/ nullptr, + }; + + auto * sampler = new llama_sampler { + /*.iface =*/ &iface, + /*.ctx =*/ nullptr, + }; + + return sampler; +} + +struct llama_sampler_backend_temp_ctx { + float temp; +}; + + +static void llama_sampler_backend_temp_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + + auto * ctx_data = (llama_sampler_backend_temp_ctx *) smpl->ctx; + + if (ctx_data->temp <= 0.0f) { + return; + } + + struct ggml_tensor * scaled = ggml_scale(ctx, ggml_data->logits, 1.0f / ctx_data->temp); + ggml_set_name(scaled, "temp_scaled"); + + // Make sure the scaled tensor is contiguous for subsequent operations + ggml_data->logits = ggml_cont(ctx, scaled); + ggml_set_name(ggml_data->logits, "temp_scaled_logits"); + + ggml_build_forward_expand(gf, ggml_data->logits); +} + +static const char * llama_sampler_backend_temp_name(const struct llama_sampler *) { + return "backend-temp"; +} + +static void llama_sampler_backend_temp_free(struct llama_sampler * smpl) { + auto * ctx_data = (llama_sampler_backend_temp_ctx *) smpl->ctx; + delete ctx_data; +} + +static struct llama_sampler * llama_sampler_backend_temp_clone(const struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_backend_temp_ctx *) smpl->ctx; + return llama_sampler_backend_init_temp(ctx->temp); +} + +struct llama_sampler * llama_sampler_backend_init_temp(float temp) { + static const llama_sampler_i iface = { + /*.name =*/ llama_sampler_backend_temp_name, + /*.accept =*/ nullptr, + /*.apply =*/ nullptr, + /*.reset =*/ nullptr, + /*.clone =*/ llama_sampler_backend_temp_clone, + /*.free =*/ llama_sampler_backend_temp_free, + /*.apply_ggml =*/ llama_sampler_backend_temp_apply_ggml, + /*.accept_ggml =*/ nullptr, + /*.set_input_ggml =*/ nullptr, + /*.set_backend_context =*/ nullptr, + }; + + auto * ctx_data = new llama_sampler_backend_temp_ctx { + /*.temp =*/ temp, + }; + + auto * sampler = new llama_sampler { + /*.iface =*/ &iface, + /*.ctx =*/ ctx_data, + }; + + return sampler; +} + + +struct llama_sampler_backend_top_k_ctx { + int32_t k; + + // Only required for checking operation support and can be removed later. 
+ ggml_backend_dev_t device; +}; + +static void llama_sampler_backend_top_k_init_ggml( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * ctx_data = (llama_sampler_backend_top_k_ctx *) smpl->ctx; + ctx_data->device = ggml_backend_buft_get_device(buft); +} + +static void llama_sampler_backend_top_k_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + + auto * ctx_data = (llama_sampler_backend_top_k_ctx *) smpl->ctx; + + struct ggml_tensor * top_k = ggml_top_k(ctx, ggml_data->logits, ctx_data->k); + ggml_set_name(top_k, "top_k"); + + // top_k is a view of argsort - check if backend supports the underlying argsort operation + // by checking the source tensor (which is the argsort result) + if (ctx_data->device && top_k->src[0] && !ggml_backend_dev_supports_op(ctx_data->device, top_k->src[0])) { + fprintf(stderr, "Warning: backend does not support argsort operation required for top-k sampling\n"); + fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n"); + } + + ggml_data->filtered_ids = top_k; + + struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]); + struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k); + ggml_set_name(top_k_rows, "top_k_rows"); + + ggml_data->logits = ggml_reshape_1d(ctx, top_k_rows, ctx_data->k); + ggml_build_forward_expand(gf, ggml_data->logits); +} + +static const char * llama_sampler_backend_top_k_name(const struct llama_sampler *) { + return "backend-top-k"; +} + +static void llama_sampler_backend_top_k_free(struct llama_sampler * smpl) { + auto * ctx_data = (llama_sampler_backend_top_k_ctx *) smpl->ctx; + delete ctx_data; +} + +static struct llama_sampler * llama_sampler_backend_top_k_clone(const struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_backend_top_k_ctx *) smpl->ctx; + return llama_sampler_backend_init_top_k(ctx->k); +} + +struct llama_sampler * llama_sampler_backend_init_top_k(int32_t k) { + static const llama_sampler_i iface = { + /*.name =*/ llama_sampler_backend_top_k_name, + /*.accept =*/ nullptr, + /*.apply =*/ nullptr, + /*.reset =*/ nullptr, + /*.clone =*/ llama_sampler_backend_top_k_clone, + /*.free =*/ llama_sampler_backend_top_k_free, + /*.apply_ggml =*/ llama_sampler_backend_top_k_apply_ggml, + /*.accept_ggml =*/ nullptr, + /*.set_input_ggml =*/ nullptr, + /*.init_ggml =*/ llama_sampler_backend_top_k_init_ggml, + }; + + auto * ctx_data = new llama_sampler_backend_top_k_ctx { + /*.k =*/ k, + /*.device =*/ nullptr, + }; + + auto * sampler = new llama_sampler { + /*.iface =*/ &iface, + /*.ctx =*/ ctx_data, + }; + + return sampler; +} + + +static uint32_t get_rng_seed(uint32_t seed) { + if (seed == LLAMA_DEFAULT_SEED) { + // use system clock if std::random_device is not a true RNG + static bool is_rd_prng = std::random_device().entropy() == 0; + if (is_rd_prng) { + return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count(); + } + std::random_device rd; + return rd(); + } + return seed; +} + +struct llama_sampler_backend_dist_ctx { + const uint32_t seed; + uint32_t seed_cur; + std::mt19937 rng; + + struct ggml_tensor * uniform; + struct ggml_context * ctx; + ggml_backend_buffer_t buffer; + + // Only required for checking operation support and can be removed later. 
+ ggml_backend_dev_t device; +}; + +static void llama_sampler_backend_dist_init_ggml( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + + auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx; + sctx->device = ggml_backend_buft_get_device(buft); + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + sctx->ctx = ggml_init(params); + + // Create the uniform random scalar input tensor. This will be set by + // llama_sampler_backend_dist_set_input_ggml after this graph is built. + sctx->uniform = ggml_new_tensor_1d(sctx->ctx, GGML_TYPE_F32, 1); + ggml_set_name(sctx->uniform, "uniform"); + ggml_set_input(sctx->uniform); + ggml_set_output(sctx->uniform); + + // Allocate all tensors from our context to the backend + sctx->buffer = ggml_backend_alloc_ctx_tensors_from_buft(sctx->ctx, buft); +} + +static void llama_sampler_backend_dist_set_input_ggml(struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx; + GGML_ASSERT(sctx->uniform != nullptr); + + std::uniform_real_distribution dist(0.0f, 1.0f); + const float rnd = dist(sctx->rng); + ggml_backend_tensor_set(sctx->uniform, &rnd, 0, sizeof(float)); +} + +static void llama_sampler_backend_dist_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + GGML_UNUSED(gf); + auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx; + + struct ggml_tensor * probs = ggml_soft_max(ctx, ggml_data->logits); + ggml_set_name(probs, "dist_probs"); + + struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs); + if (sctx->device && !ggml_backend_dev_supports_op(sctx->device, cumsum)) { + fprintf(stderr, "Warning: backend does not support cumsum operation required for dist sampling\n"); + fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n"); + } + ggml_set_name(cumsum, "cumsum"); + + // The uniform tensor has a random value and we subtract this tensor with + // the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub). + // Recall that each entry in cumsum is the cumulative probability up to that + // index so values stay negative while the cumulative total is below the + // random value, and become zero/positive once the threshold is crossed. + struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->uniform); + ggml_set_name(diff, "dist_cumsum"); + + // The ggml_step function produces a tensor where entries are 1 if the + // corresponding entry in diff is > 0, and 0 otherwise. So all values up to + // the index where the cumulative probability exceeds the random value are 0, + // and all entries after that are 1. + struct ggml_tensor * mask = ggml_step(ctx, diff); + ggml_set_name(mask, "dist_mask"); + + // Taking the sum of the mask gives us the sum of elements after the threshold + // we are interested in. + struct ggml_tensor * idxf = ggml_sum(ctx, mask); + ggml_set_name(idxf, "dist_index_f32"); + + // Use ggml_scale_bias to scale the index value by -1 and then add the size + // of the mask to that value so we get the correct index ((-1 * idxf) + n). + struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32); + ggml_set_name(idx, "dist_index_i32"); + + // Map back to original vocab ids if a filtered id tensor is available. 
+ struct ggml_tensor * sampled_token = idx; + if (ggml_data->filtered_ids != nullptr) { + struct ggml_tensor * filtered_ids = ggml_data->filtered_ids; + struct ggml_tensor * filtered_ids_reshaped = ggml_view_2d(ctx, filtered_ids, 1, ggml_nelements(filtered_ids), + ggml_type_size(filtered_ids->type), 0); + + sampled_token = ggml_get_rows(ctx, filtered_ids_reshaped, idx); + ggml_set_name(sampled_token, "dist_sampled_token"); + } + + ggml_set_output(sampled_token); + ggml_data->sampled_token = sampled_token; +} + +static const char * llama_sampler_backend_dist_name(const struct llama_sampler *) { + return "backend-dist"; +} + +static void llama_sampler_backend_dist_free(struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx; + ggml_backend_buffer_free(sctx->buffer); + ggml_free(sctx->ctx); + delete sctx; +} + +static struct llama_sampler * llama_sampler_backend_dist_clone(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx; + return llama_sampler_backend_init_dist(sctx->seed); +} + + +struct llama_sampler * llama_sampler_backend_init_dist(uint32_t seed) { + static const llama_sampler_i iface = { + /*.name =*/ llama_sampler_backend_dist_name, + /*.accept =*/ nullptr, + /*.apply =*/ nullptr, + /*.reset =*/ nullptr, + /*.clone =*/ llama_sampler_backend_dist_clone, + /*.free =*/ llama_sampler_backend_dist_free, + /*.apply_ggml =*/ llama_sampler_backend_dist_apply_ggml, + /*.accept_ggml =*/ nullptr, + /*.set_input_ggml =*/ llama_sampler_backend_dist_set_input_ggml, + /*.init_ggml =*/ llama_sampler_backend_dist_init_ggml, + }; + + auto seed_cur = get_rng_seed(seed); + auto * ctx_data = new llama_sampler_backend_dist_ctx { + /*.seed =*/ seed, + /*.seed_cur =*/ seed_cur, + /*.rng =*/ std::mt19937(seed_cur), + /*.uniform =*/ nullptr, + /*.ctx =*/ nullptr, + /*.buffer =*/ nullptr, + /*.device =*/ nullptr, + }; + + auto * sampler = new llama_sampler { + /*.iface =*/ &iface, + /*.ctx =*/ ctx_data, + }; + + return sampler; +} + +struct llama_sampler_backend_logit_bias_ctx { + const int32_t n_vocab; + + const std::vector logit_bias; + + struct ggml_tensor * logit_bias_t; + struct ggml_context * ctx; + ggml_backend_buffer_t buffer; +}; + +static void llama_sampler_backend_logit_bias_init_ggml( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx; + if (sctx->logit_bias.empty()) { + return; + } + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead() * sctx->n_vocab * sizeof(float), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + sctx->ctx = ggml_init(params); + + struct ggml_tensor * logit_bias = ggml_new_tensor_1d(sctx->ctx, GGML_TYPE_F32, sctx->n_vocab); + sctx->logit_bias_t = logit_bias; + ggml_set_name(sctx->logit_bias_t, "logit_bias"); + ggml_set_input(sctx->logit_bias_t); + ggml_set_output(sctx->logit_bias_t); + + // Allocate all tensors from our context to the backend + sctx->buffer = ggml_backend_alloc_ctx_tensors_from_buft(sctx->ctx, buft); +} + +static void llama_sampler_backend_logit_bias_set_input_ggml(struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx; + if (sctx->logit_bias.empty()) { + return; + } + GGML_ASSERT(sctx->logit_bias_t != nullptr); + + // Create a sparse logit_bias vector from the logit_bias entries. 
+    std::vector logit_bias_sparse(sctx->n_vocab, 0.0f);
+    for (const auto & lb : sctx->logit_bias) {
+        GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
+        logit_bias_sparse[lb.token] = lb.bias;
+    }
+
+    ggml_backend_tensor_set(sctx->logit_bias_t, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->logit_bias_t));
+}
+
+static void llama_sampler_backend_logit_bias_apply_ggml(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_ggml_data * ggml_data) {
+    GGML_UNUSED(gf);
+    GGML_UNUSED(ctx);
+
+    auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx;
+    if (sctx->logit_bias_t == nullptr) {
+        return;
+    }
+
+    // Add the sparse logit bias tensor to the logits
+    struct ggml_tensor * logit_biased = ggml_add_inplace(sctx->ctx, ggml_data->logits, sctx->logit_bias_t);
+    ggml_build_forward_expand(gf, logit_biased);
+}
+
+static const char * llama_sampler_backend_logit_bias_name(const struct llama_sampler *) {
+    return "backend-logit_bias";
+}
+
+static void llama_sampler_backend_logit_bias_free(struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx;
+    ggml_backend_buffer_free(sctx->buffer);
+    ggml_free(sctx->ctx);
+    delete sctx;
+}
+
+static struct llama_sampler * llama_sampler_backend_logit_bias_clone(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx;
+    return llama_sampler_backend_init_logit_bias(sctx->n_vocab,
+                                                 sctx->logit_bias.size(),
+                                                 sctx->logit_bias.data());
+}
+
+
+struct llama_sampler * llama_sampler_backend_init_logit_bias(int32_t n_vocab,
+                                                             int32_t n_logit_bias,
+                                                             const llama_logit_bias * logit_bias) {
+    static const llama_sampler_i iface = {
+        /*.name =*/ llama_sampler_backend_logit_bias_name,
+        /*.accept =*/ nullptr,
+        /*.apply =*/ nullptr,
+        /*.reset =*/ nullptr,
+        /*.clone =*/ llama_sampler_backend_logit_bias_clone,
+        /*.free =*/ llama_sampler_backend_logit_bias_free,
+        /*.apply_ggml =*/ llama_sampler_backend_logit_bias_apply_ggml,
+        /*.accept_ggml =*/ nullptr,
+        /*.set_input_ggml =*/ llama_sampler_backend_logit_bias_set_input_ggml,
+        /*.init_ggml =*/ llama_sampler_backend_logit_bias_init_ggml,
+    };
+
+    auto * ctx_data = new llama_sampler_backend_logit_bias_ctx {
+        /*.n_vocab =*/ n_vocab,
+        /*.logit_bias =*/ std::vector(logit_bias, logit_bias + n_logit_bias),
+        /*.logit_bias_t =*/ nullptr,
+        /*.ctx =*/ nullptr,
+        /*.buffer =*/ nullptr,
+    };
+
+    auto * sampler = new llama_sampler {
+        /*.iface =*/ &iface,
+        /*.ctx =*/ ctx_data,
+    };
+
+    return sampler;
+}
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 70a3ec62df..877116cbfe 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -58,6 +58,16 @@ llama_context::llama_context(
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;

+    // backend samplers
+    if (params.samplers != nullptr && params.n_samplers > 0) {
+        samplers.reserve(params.n_samplers);
+
+        for (size_t i = 0; i < params.n_samplers; ++i) {
+            const auto & config = params.samplers[i];
+            samplers[config.seq_id] = config.sampler;
+        }
+    }
+
     auto rope_scaling_type = params.rope_scaling_type;
     if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
         rope_scaling_type = hparams.rope_scaling_type_train;
@@ -424,6 +434,10 @@ llama_context::llama_context(
 llama_context::~llama_context() {
     ggml_opt_free(opt_ctx);
+    // TODO: perhaps use a smart pointer for samplers
+    for (auto const& [seq_id, sampler] : samplers) {
+        llama_sampler_free(sampler);
+    }
 }

 void 
llama_context::synchronize() {
@@ -610,6 +624,10 @@ float * llama_context::get_embeddings() {
     return embd;
 }

+llama_token * llama_context::get_backend_sampled_tokens() {
+    return sampled_tokens;
+}
+
 float * llama_context::get_embeddings_ith(int32_t i) {
     int64_t j = -1;

@@ -659,6 +677,98 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
     return it->second.data();
 }

+llama_token llama_context::get_backend_sampled_token_ith(int32_t idx) {
+    // Handle special case where idx == -1 (single sequence exists) which is
+    // a valid index when using common_sampler_sample.
+    if (idx == -1) {
+        if (sampled_tokens_map.size() == 1) {
+            auto it = sampled_tokens_map.begin();
+            return it->second;
+        }
+        return LLAMA_TOKEN_NULL;
+    }
+
+    auto it = sampled_tokens_map.find(idx);
+    if (it == sampled_tokens_map.end()) {
+        return LLAMA_TOKEN_NULL;
+    }
+
+    return it->second;
+}
+
+float * llama_context::get_backend_sampled_probs_ith(int32_t idx) {
+    if (idx == -1) {
+        if (sampled_probs_map.size() == 1) {
+            return sampled_probs_map.begin()->second.data();
+        }
+    }
+
+    auto it = sampled_probs_map.find(idx);
+    if (it == sampled_probs_map.end()) {
+        return nullptr;
+    }
+
+    return it->second.data();
+}
+
+float * llama_context::get_backend_sampled_logits_ith(int32_t idx) {
+    if (idx == -1) {
+        if (sampled_logits_map.size() == 1) {
+            return sampled_logits_map.begin()->second.data();
+        }
+    }
+    auto it = sampled_logits_map.find(idx);
+    if (it == sampled_logits_map.end()) {
+        return nullptr;
+    }
+
+    return it->second.data();
+}
+
+const llama_token * llama_context::get_backend_sampled_token_ids_ith(int32_t idx) {
+    if (idx == -1) {
+        if (sampled_token_ids_map.size() == 1) {
+            return sampled_token_ids_map.begin()->second.data();
+        }
+    }
+    auto it = sampled_token_ids_map.find(idx);
+    if (it == sampled_token_ids_map.end() || it->second.empty()) {
+        return nullptr;
+    }
+
+    return it->second.data();
+}
+
+size_t llama_context::get_backend_sampled_logits_count(int32_t idx) const {
+    if (idx == -1) {
+        if (sampled_logits_map.size() == 1) {
+            return sampled_logits_map.begin()->second.size();
+        }
+    }
+    auto it = sampled_logits_map.find(idx);
+    if (it == sampled_logits_map.end()) {
+        return 0;
+    }
+
+    return it->second.size();
+}
+
+size_t llama_context::get_backend_sampled_probs_count(int32_t idx) const {
+    if (idx == -1) {
+        if (sampled_probs_map.size() == 1) {
+            return sampled_probs_map.begin()->second.size();
+        }
+        return 0;
+    }
+
+    auto it = sampled_probs_map.find(idx);
+    if (it == sampled_probs_map.end()) {
+        return 0;
+    }
+
+    return it->second.size();
+}
+
 void llama_context::attach_threadpool(
         ggml_threadpool_t threadpool,
         ggml_threadpool_t threadpool_batch) {
@@ -715,6 +825,37 @@ void llama_context::set_warmup(bool value) {
     cparams.warmup = value;
 }

+void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
+    LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
+
+    auto it = samplers.find(seq_id);
+    if (it != samplers.end()) {
+        // If the sampler to be set is the same that is already set, do nothing.
+        if (it->second == sampler) {
+            return;
+        }
+
+        llama_sampler_free(it->second);
+
+        // If sampler is nullptr, we remove the sampler
+        // chain for this seq_id.
+        if (sampler == nullptr) {
+            samplers.erase(it);
+            return;
+        }
+
+        // Otherwise, we replace the existing sampler with the new one.
+        it->second = sampler;
+        return;
+    }
+
+    // If there is no sampler for this seq_id and the caller provides a non-null
+    // sampler, we set it.
+    if (sampler != nullptr) {
+        samplers[seq_id] = sampler;
+    }
+}
+
 void llama_context::set_adapter_lora(
             llama_adapter_lora * adapter,
             float scale) {
@@ -1029,6 +1170,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
+    sampled_probs_map.clear();
+    sampled_logits_map.clear();
+    sampled_tokens_map.clear();
+    sampled_token_ids_map.clear();
     output_swaps.clear();

     bool did_optimize = false;
@@ -1088,6 +1233,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
     };

     int64_t n_outputs_prev = 0;
+    // This flag indicates whether a backend sampler has actually sampled a specific
+    // token, or if it has produced probabilities. If true, we can skip
+    // the normal copying of logits and embeddings.
+    bool backend_has_sampled = false;

     do {
         const auto & ubatch = mctx->get_ubatch();
@@ -1147,80 +1296,131 @@ int llama_context::decode(const llama_batch & batch_inp) {
        //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
        //}

-        auto * t_logits = res->get_logits();
-        auto * t_embd   = cparams.embeddings ? res->get_embd() : nullptr;
-
-        if (t_embd && res->get_embd_pooled()) {
-            t_embd = res->get_embd_pooled();
-        }
-
-        // extract logits
-        if (t_logits && n_outputs > 0) {
-            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
-            GGML_ASSERT(backend_res != nullptr);
-            GGML_ASSERT(logits != nullptr);
-
-            float * logits_out = logits + n_outputs_prev*n_vocab;
-
-            if (n_outputs) {
-                GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
-                GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
-                ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
+        std::unordered_map seq_to_idx;
+        for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+            if (ubatch.output[i]) {
+                llama_seq_id seq_id = ubatch.seq_id[i][0];
+                seq_to_idx[seq_id] = i;
             }
         }

-        // extract embeddings
-        if (t_embd && n_outputs > 0) {
-            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
-            GGML_ASSERT(backend_embd != nullptr);
+        // extract sampled tokens
+        for (const auto & [seq_id, t_token] : res->t_sampled_tokens) {
+            auto idx_it = seq_to_idx.find(seq_id);
+            GGML_ASSERT(idx_it != seq_to_idx.end());
+            const int32_t idx = idx_it->second;
+            ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_token);
+            ggml_backend_tensor_get_async(backend, t_token, &sampled_tokens_map[idx], 0, sizeof(llama_token));
+        }

-            switch (cparams.pooling_type) {
-                case LLAMA_POOLING_TYPE_NONE:
-                    {
-                        // extract token embeddings
-                        GGML_ASSERT(embd != nullptr);
-                        float * embd_out = embd + n_outputs_prev*n_embd;
+        for (const auto & [seq_id, t_ids] : res->t_sampled_token_ids) {
+            auto idx_it = seq_to_idx.find(seq_id);
+            GGML_ASSERT(idx_it != seq_to_idx.end());
+            const int32_t idx = idx_it->second;
+            ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_ids);
+            sampled_token_ids_map[idx].resize(ggml_nelements(t_ids));
+            ggml_backend_tensor_get_async(backend, t_ids, sampled_token_ids_map[idx].data(), 0, ggml_nbytes(t_ids));
+        }

-                        if (n_outputs) {
-                            GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
-                            GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
-                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, 
n_outputs*n_embd*sizeof(float)); + if (res->t_sampled_tokens.empty()) { + for (const auto & [seq_id, t_logits] : res->t_sampled_logits) { + auto idx_it = seq_to_idx.find(seq_id); + GGML_ASSERT(idx_it != seq_to_idx.end()); + const int32_t idx = idx_it->second; + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + sampled_logits_map[idx].resize(ggml_nelements(t_logits)); + ggml_backend_tensor_get_async(backend, t_logits, sampled_logits_map[idx].data(), 0, ggml_nbytes(t_logits)); + } + + // extract sampled probabilities + for (const auto & [seq_id, t_probs] : res->t_sampled_probs) { + auto idx_it = seq_to_idx.find(seq_id); + GGML_ASSERT(idx_it != seq_to_idx.end()); + const int32_t idx = idx_it->second; + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_probs); + sampled_probs_map[idx].resize(ggml_nelements(t_probs)); + ggml_backend_tensor_get_async(backend, t_probs, sampled_probs_map[idx].data(), 0, ggml_nbytes(t_probs)); + } + } + + backend_has_sampled = !res->t_sampled_tokens.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty(); + + if (!backend_has_sampled) { + auto * t_logits = res->get_logits(); + auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + + if (t_embd && res->get_embd_pooled()) { + t_embd = res->get_embd_pooled(); + } + + // extract logits + if (t_logits && n_outputs > 0) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); + + float * logits_out = logits + n_outputs_prev*n_vocab; + + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); + ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); + } + } + + // extract embeddings + if (t_embd && n_outputs > 0) { + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); + GGML_ASSERT(backend_embd != nullptr); + + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // extract token embeddings + GGML_ASSERT(embd != nullptr); + float * embd_out = embd + n_outputs_prev*n_embd; + + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings (cleared before processing each batch) + auto & embd_seq_out = embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + const int32_t seq_idx = ubatch.seq_idx[seq_id]; + + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_RANK: + { + // extract the rerank score - n_cls_out floats per sequence + auto & embd_seq_out = embd_seq; + + const uint32_t n_cls_out = hparams.n_cls_out; + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + const int32_t seq_idx = ubatch.seq_idx[seq_id]; + + embd_seq_out[seq_id].resize(n_cls_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, 
embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); } - } break; - case LLAMA_POOLING_TYPE_MEAN: - case LLAMA_POOLING_TYPE_CLS: - case LLAMA_POOLING_TYPE_LAST: - { - // extract sequence embeddings (cleared before processing each batch) - auto & embd_seq_out = embd_seq; - - for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { - const llama_seq_id seq_id = ubatch.seq_id_unq[s]; - const int32_t seq_idx = ubatch.seq_idx[seq_id]; - - embd_seq_out[seq_id].resize(n_embd); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_RANK: - { - // extract the rerank score - n_cls_out floats per sequence - auto & embd_seq_out = embd_seq; - - const uint32_t n_cls_out = hparams.n_cls_out; - - for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { - const llama_seq_id seq_id = ubatch.seq_id_unq[s]; - const int32_t seq_idx = ubatch.seq_idx[seq_id]; - - embd_seq_out[seq_id].resize(n_cls_out); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_UNSPECIFIED: - { - GGML_ABORT("unknown pooling type"); - } + } } } @@ -1231,7 +1431,7 @@ int llama_context::decode(const llama_batch & batch_inp) { n_outputs = n_outputs_all; // set output mappings - if (n_outputs > 0) { + if (n_outputs > 0 && !backend_has_sampled) { bool sorted_output = true; auto & out_ids = balloc->get_out_ids(); @@ -1345,9 +1545,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { } float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); + llama_token * s_output_base = (llama_token *) ggml_backend_buffer_get_base(buf_output.get()); - logits = has_logits ? output_base : nullptr; - embd = has_embd ? output_base + logits_size : nullptr; + logits = has_logits ? output_base : nullptr; + embd = has_embd ? output_base + logits_size : nullptr; + sampled_tokens = !samplers.empty() ? s_output_base : nullptr; + sampled_probs = !samplers.empty() ? 
embd : nullptr; // set all ids as invalid (negative) std::fill(output_ids.begin(), output_ids.end(), -1); @@ -1456,6 +1659,7 @@ llm_graph_params llama_context::graph_params( /*.loras =*/ &loras, /*.mctx =*/ mctx, /*.cross =*/ &cross, + /*.samplers =*/ samplers, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), /*.res =*/ res, @@ -2319,6 +2523,8 @@ llama_context_params llama_context_default_params() { /*.op_offload =*/ true, /*.swa_full =*/ true, /*.kv_unified =*/ false, + /*.sampler =*/ nullptr, + /*.n_sampler =*/ 0, }; return result; @@ -2478,6 +2684,13 @@ float * llama_get_logits(llama_context * ctx) { float * llama_get_logits_ith(llama_context * ctx, int32_t i) { ctx->synchronize(); + if (ctx->get_backend_sampled_token_ith(i) != LLAMA_TOKEN_NULL) { + return nullptr; + } + if (ctx->get_backend_sampled_probs_ith(i) != nullptr) { + return nullptr; + } + return ctx->get_logits_ith(i); } @@ -2499,6 +2712,46 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } +void llama_set_backend_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * sampler) { + ctx->set_backend_sampler(seq_id, sampler); +} + +llama_token llama_get_backend_sampled_token_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_backend_sampled_token_ith(i); +} + +float * llama_get_backend_sampled_probs_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_backend_sampled_probs_ith(i); +} + +float * llama_get_backend_sampled_logits_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_backend_sampled_logits_ith(i); +} + +llama_token * llama_get_backend_sampled_token_ids_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return const_cast(ctx->get_backend_sampled_token_ids_ith(i)); +} + +uint32_t llama_get_backend_sampled_logits_count_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return static_cast(ctx->get_backend_sampled_logits_count(i)); +} + +uint32_t llama_get_backend_sampled_probs_count_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return static_cast(ctx->get_backend_sampled_probs_count(i)); +} + // llama adapter API int32_t llama_set_adapter_lora( diff --git a/src/llama-context.h b/src/llama-context.h index 20cbd78955..b9020beff1 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -66,6 +66,16 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + llama_token * get_backend_sampled_tokens(); + llama_token get_backend_sampled_token_ith(int32_t idx); + + float * get_backend_sampled_logits_ith(int32_t idx); + const llama_token * get_backend_sampled_token_ids_ith(int32_t idx); + size_t get_backend_sampled_logits_count(int32_t idx) const; + + float * get_backend_sampled_probs_ith(int32_t idx); + size_t get_backend_sampled_probs_count(int32_t idx) const; + void attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch); @@ -208,6 +218,8 @@ public: // reserve a graph with a dummy ubatch of the specified size ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false); + void set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler); + private: llm_graph_params graph_params( llm_graph_result * res, @@ -242,6 +254,16 @@ private: size_t logits_size = 0; // capacity (of floats) for logits float * logits = nullptr; + std::unordered_map 
samplers;
+    llama_token * sampled_tokens = nullptr;
+    std::unordered_map sampled_tokens_map;
+
+    float * sampled_probs = nullptr;
+    std::unordered_map> sampled_probs_map;
+
+    std::unordered_map> sampled_logits_map;
+    std::unordered_map> sampled_token_ids_map;
+
     // embeddings output (2-dimensional array: [n_outputs][n_embd])
     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
     size_t  embd_size = 0; // capacity (of floats) for embeddings
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 650e40ec6f..49aab37f33 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -3,6 +3,7 @@
 #include "llama-impl.h"
 #include "llama-batch.h"
 #include "llama-cparams.h"
+#include "llama-model.h"
 #include "llama-kv-cache.h"
 #include "llama-kv-cache-iswa.h"
@@ -462,6 +463,28 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     inp_rs->set_input(ubatch);
 }

+void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
+    for (const auto & [seq_id, sampler] : samplers) {
+        if (sampler->iface->set_input_ggml) {
+            sampler->iface->set_input_ggml(sampler);
+        }
+    }
+}
+
+bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
+    if (params.samplers.empty()) {
+        return true;
+    }
+
+    for (const auto & [seq_id, sampler] : params.samplers) {
+        if (sampler_versions[seq_id] != llama_sampler_chain_get_version(sampler)) {
+            return false;
+        }
+    }
+    return true;
+}
+
 //
 // llm_graph_result
 //
@@ -482,6 +505,10 @@ void llm_graph_result::reset() {
     t_logits      = nullptr;
     t_embd        = nullptr;
     t_embd_pooled = nullptr;
+    t_sampled_tokens.clear();
+    t_sampled_probs.clear();
+    t_sampled_logits.clear();
+    t_sampled_token_ids.clear();

     params = {};

@@ -587,6 +614,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     loras            (params.loras),
     mctx             (params.mctx),
     cross            (params.cross),
+    samplers         (params.samplers),
     cb_func          (params.cb),
     res              (params.res),
     ctx0             (res->get_ctx()),
@@ -2021,6 +2049,103 @@ void llm_graph_context::build_pooling(
     ggml_build_forward_expand(gf, cur);
 }

+void llm_graph_context::build_sampling(const llama_model & model, const llm_graph_params & params) const {
+    GGML_UNUSED(params);
+    if (samplers.empty()) {
+        return;
+    }
+
+    std::unordered_map seq_to_logit_row;
+    int32_t logit_row_idx = 0;
+
+    for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+        if (ubatch.output[i]) {
+            llama_seq_id seq_id = ubatch.seq_id[i][0];
+            seq_to_logit_row[seq_id] = logit_row_idx;
+            logit_row_idx++;
+        }
+    }
+    if (seq_to_logit_row.empty()) {
+        return;
+    }
+
+    // res->t_logits will contain logits for all tokens that specified that they want
+    // logits calculated (logits=1 or output=1)
+    ggml_tensor * logits_t = res->t_logits;
+    GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
+
+    const int n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(&model));
+    GGML_ASSERT(logits_t->ne[0] == n_vocab);
+
+    ggml_backend_dev_t device = model.dev_output();
+    ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(device);
+
+    std::unordered_map active_samplers;
+
+    for (const auto & [seq_id, sampler] : samplers) {
+        // Only process samplers for sequences that are in the current batch
+        auto it = seq_to_logit_row.find(seq_id);
+        if (it == seq_to_logit_row.end()) {
+            continue;
+        }
+        const int32_t row_idx = it->second;
+
+        // Allow GPU sampler to create input tensors by implementing init_ggml.
+ if (sampler->iface->init_ggml != nullptr) { + sampler->iface->init_ggml(sampler, buft); + } + + active_samplers[seq_id] = sampler; + + ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, n_vocab, row_idx * logits_t->nb[1]); + ggml_format_name(logits_seq, "logits_seq_%d", seq_id); + + struct llama_sampler_ggml_data ggml_data = { + /*.logits =*/ logits_seq, + /*.probs =*/ nullptr, + /*.sampled_token =*/ nullptr, + /*.filtered_ids =*/ nullptr, + }; + + llama_sampler_apply_ggml(sampler, ctx0, gf, &ggml_data); + + if (ggml_data.sampled_token != nullptr) { + res->t_sampled_tokens[seq_id] = ggml_data.sampled_token; + ggml_build_forward_expand(gf, ggml_data.sampled_token); + } + + if (ggml_data.probs != nullptr) { + res->t_sampled_probs[seq_id] = ggml_data.probs; + ggml_build_forward_expand(gf, ggml_data.probs); + } + + if (ggml_data.logits != logits_seq) { + res->t_sampled_logits[seq_id] = ggml_data.logits; + ggml_build_forward_expand(gf, res->t_sampled_logits[seq_id]); + } + + if (ggml_data.filtered_ids != nullptr) { + res->t_sampled_token_ids[seq_id] = ggml_data.filtered_ids; + ggml_build_forward_expand(gf, ggml_data.filtered_ids); + } + } + + // TODO: Call llama_sampler_accept_ggml after all samplers have been applied. + /* + for (const auto & [seq_id, sampler] : samplers) { + if (auto it = res->t_sampled_tokens.find(seq_id); it != res->t_sampled_tokens.end()) { + ggml_tensor * selected_token = it->second; + if (selected_token != nullptr) { + llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token); + } + } + } + */ + + auto inp_sampling = std::make_unique(n_vocab, false, active_samplers); + res->add_input(std::move(inp_sampling)); +} + int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { // TODO move to hparams if a T5 variant appears that uses a different value const int64_t max_distance = 128; diff --git a/src/llama-graph.h b/src/llama-graph.h index d0c3934f67..bd176e5d38 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -383,6 +383,32 @@ public: const llama_memory_hybrid_context * mctx; }; +class llm_graph_input_sampling : public llm_graph_input_i { +public: + llm_graph_input_sampling(int32_t n_vocab, bool sorted, + std::unordered_map samplers) : + n_vocab(n_vocab), sorted_value(sorted), samplers(samplers) { + + sampler_versions.reserve(samplers.size()); + for (const auto & [seq_id, sampler] : samplers) { + sampler_versions[seq_id] = llama_sampler_chain_get_version(sampler); + } + } + virtual ~llm_graph_input_sampling() = default; + + void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + + int32_t n_vocab; + bool sorted_value; + ggml_tensor * size = nullptr; // I32 [1] + ggml_tensor * sorted = nullptr; // I32 [1] + + // Track sampler chain version for reuse + std::unordered_map sampler_versions; + std::unordered_map samplers; +}; + // // llm_graph_result // @@ -416,6 +442,23 @@ struct llm_graph_params { const llama_memory_context_i * mctx; const llama_cross * cross; + std::unordered_map samplers; + + static bool samplers_equal( + const std::unordered_map & lhs, + const std::unordered_map & rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (const auto & [seq_id, sampler] : lhs) { + auto it = rhs.find(seq_id); + if (it == rhs.end() || it->second != sampler) { + return false; + } + } + return true; + } + uint32_t n_outputs; llm_graph_cb cb; @@ -463,7 +506,9 @@ struct llm_graph_params { cvec == other.cvec && loras == other.loras && cross == 
other.cross && - n_outputs == other.n_outputs; + n_outputs == other.n_outputs && + samplers_equal(samplers, other.samplers); + } }; @@ -504,6 +549,11 @@ public: ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + std::unordered_map t_sampled_logits; + std::unordered_map t_sampled_token_ids; + std::unordered_map t_sampled_tokens; + std::unordered_map t_sampled_probs; + std::vector inputs; ggml_context_ptr ctx_compute; @@ -579,6 +629,8 @@ struct llm_graph_context { const llama_memory_context_i * mctx; const llama_cross * cross; + std::unordered_map samplers; + const llm_graph_cb & cb_func; llm_graph_result * res; @@ -819,6 +871,12 @@ struct llm_graph_context { ggml_tensor * cls_out, ggml_tensor * cls_out_b) const; + // + // sampling (backend sampling) + // + + void build_sampling(const llama_model & model, const llm_graph_params & params) const; + // // dense (out) // diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e703181a19..ca75ce4c9e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -7412,6 +7412,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + // add backend sampling layers (if any) + llm->build_sampling(*this, params); + // if the gguf model was converted with --sentence-transformers-dense-modules // there will be two additional dense projection layers // dense linear projections are applied after pooling diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index adb3f8810e..dc9227c1a5 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -372,6 +372,39 @@ void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_ar smpl->iface->apply(smpl, cur_p); } +void llama_sampler_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + GGML_ASSERT(smpl->iface->apply_ggml); + smpl->iface->apply_ggml(smpl, ctx, gf, ggml_data); +} + +void llama_sampler_accept_ggml( + struct llama_sampler * smpl, + ggml_context * ctx, + ggml_cgraph * gf, + struct ggml_tensor * selected_token) { + if (smpl->iface->accept_ggml) { + smpl->iface->accept_ggml(smpl, ctx, gf, selected_token); + } +} + +void llama_sampler_set_input_ggml(struct llama_sampler * smpl) { + if (smpl->iface->set_input_ggml) { + smpl->iface->set_input_ggml(smpl); + } +} + +void llama_sampler_init_ggml( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + if (smpl->iface->init_ggml) { + smpl->iface->init_ggml(smpl, buft); + } +} + void llama_sampler_reset(struct llama_sampler * smpl) { if (smpl->iface->reset) { smpl->iface->reset(smpl); @@ -406,7 +439,15 @@ void llama_sampler_free(struct llama_sampler * smpl) { } llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { - const auto * logits = llama_get_logits_ith(ctx, idx); + const llama_token sampled_token = llama_get_backend_sampled_token_ith(ctx, idx); + const float * sampled_probs = llama_get_backend_sampled_probs_ith(ctx, idx); + const float * sampled_logits = llama_get_backend_sampled_logits_ith(ctx, idx); + const llama_token * sampled_ids = llama_get_backend_sampled_token_ids_ith(ctx, idx); + + // If a backend sampler has already sampled a token, return it. 
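+ // LLAMA_TOKEN_NULL means no backend sampler produced a token for this index; in that case we fall back to building the candidate array from the backend-filtered probs/logits, or from the full logits, below.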
+ if (sampled_token != LLAMA_TOKEN_NULL) { + return sampled_token; + } const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); @@ -415,9 +456,40 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte // TODO: do not allocate each time std::vector cur; - cur.reserve(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + + if (sampled_probs) { + const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx); + cur.reserve(sampled_probs_count); + // The backend sampler has filtered the probabilities so we need to use the sampled ids. + if (sampled_ids != nullptr) { + for (uint32_t i = 0; i < sampled_probs_count; ++i) { + cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]}); + } + } else { + for (llama_token token_id = 0; token_id < (int) sampled_probs_count; token_id++) { + cur.emplace_back(llama_token_data{token_id, 0.0f, sampled_probs[token_id]}); + } + } + } else if (sampled_logits) { + const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx); + cur.reserve(sampled_logits_count); + // The backend sampler has filtered the logits so we need to use the sampled ids. + if (sampled_ids != nullptr) { + for (llama_token i = 0; i < (int)sampled_logits_count; i++) { + cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f}); + } + } else { + for (llama_token token_id = 0; token_id < (int)sampled_logits_count; token_id++) { + cur.emplace_back(llama_token_data{token_id, sampled_logits[token_id], 0.0f}); + } + } + } else { + const auto * logits = llama_get_logits_ith(ctx, idx); + GGML_ASSERT(logits != nullptr); + cur.reserve(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } } llama_token_data_array cur_p = { @@ -462,6 +534,10 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d time_meas tm(chain->t_sample_us, chain->params.no_perf); for (auto * smpl : chain->samplers) { + // Skip GPU samplers - they have apply_ggml but no apply + if (smpl->iface->apply == nullptr) { + continue; + } llama_sampler_apply(smpl, cur_p); } } @@ -499,13 +575,67 @@ static void llama_sampler_chain_free(struct llama_sampler * smpl) { delete chain; } +static void llama_sampler_chain_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + if (smpl->iface->apply_ggml) { + smpl->iface->apply_ggml(smpl, ctx, gf, ggml_data); + } + } +} + +static void llama_sampler_chain_accept_ggml( + struct llama_sampler * smpl, + ggml_context * ctx, + ggml_cgraph * gf, + struct ggml_tensor * selected_token) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + if (smpl->iface->accept_ggml) { + smpl->iface->accept_ggml(smpl, ctx, gf, selected_token); + } + } +} + +static void llama_sampler_chain_set_input_ggml(struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + if (smpl->iface->set_input_ggml) { + smpl->iface->set_input_ggml(smpl); + } + } +} + +static void llama_sampler_chain_set_backend_context( + struct llama_sampler * smpl, + 
ggml_backend_buffer_type_t buft) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + if (smpl->iface->init_ggml) { + smpl->iface->init_ggml(smpl,buft); + } + } +} + static struct llama_sampler_i llama_sampler_chain_i = { - /* .name = */ llama_sampler_chain_name, - /* .accept = */ llama_sampler_chain_accept, - /* .apply = */ llama_sampler_chain_apply, - /* .reset = */ llama_sampler_chain_reset, - /* .clone = */ llama_sampler_chain_clone, - /* .free = */ llama_sampler_chain_free, + /* .name = */ llama_sampler_chain_name, + /* .accept = */ llama_sampler_chain_accept, + /* .apply = */ llama_sampler_chain_apply, + /* .reset = */ llama_sampler_chain_reset, + /* .clone = */ llama_sampler_chain_clone, + /* .free = */ llama_sampler_chain_free, + /* .apply_ggml = */ llama_sampler_chain_apply_ggml, + /* .accept_ggml = */ llama_sampler_chain_accept_ggml, + /* .set_input_ggml = */ llama_sampler_chain_set_input_ggml, + /* .set_backend_context = */ llama_sampler_chain_set_backend_context, }; struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { @@ -523,6 +653,7 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { auto * p = (llama_sampler_chain *) chain->ctx; p->samplers.push_back(smpl); + p->version++; } struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) { @@ -544,6 +675,7 @@ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, auto * result = p->samplers[i]; p->samplers.erase(p->samplers.begin() + i); + p->version++; return result; } @@ -554,6 +686,11 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) { return p->samplers.size(); } +uint64_t llama_sampler_chain_get_version(const struct llama_sampler * chain) { + const auto * p = (const llama_sampler_chain *) chain->ctx; + return p->version; +} + // // samplers // @@ -574,12 +711,16 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to } static struct llama_sampler_i llama_sampler_greedy_i = { - /* .name = */ llama_sampler_greedy_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_greedy_apply, - /* .reset = */ nullptr, - /* .clone = */ nullptr, - /* .free = */ nullptr, + /* .name = */ llama_sampler_greedy_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_greedy_apply, + /* .reset = */ nullptr, + /* .clone = */ nullptr, + /* .free = */ nullptr, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_greedy() { @@ -699,12 +840,16 @@ static void llama_sampler_dist_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_dist_i = { - /* .name = */ llama_sampler_dist_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_dist_apply, - /* .reset = */ llama_sampler_dist_reset, - /* .clone = */ llama_sampler_dist_clone, - /* .free = */ llama_sampler_dist_free, + /* .name = */ llama_sampler_dist_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_dist_apply, + /* .reset = */ llama_sampler_dist_reset, + /* .clone = */ llama_sampler_dist_clone, + /* .free = */ llama_sampler_dist_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * 
llama_sampler_init_dist(uint32_t seed) { @@ -744,12 +889,16 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_top_k_i = { - /* .name = */ llama_sampler_top_k_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_k_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_k_clone, - /* .free = */ llama_sampler_top_k_free, + /* .name = */ llama_sampler_top_k_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_k_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_k_clone, + /* .free = */ llama_sampler_top_k_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_k(int32_t k) { @@ -839,12 +988,16 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_top_p_i = { - /* .name = */ llama_sampler_top_p_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_p_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_p_clone, - /* .free = */ llama_sampler_top_p_free, + /* .name = */ llama_sampler_top_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_p_clone, + /* .free = */ llama_sampler_top_p_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { @@ -933,12 +1086,16 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_min_p_i = { - /* .name = */ llama_sampler_min_p_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_min_p_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_min_p_clone, - /* .free = */ llama_sampler_min_p_free, + /* .name = */ llama_sampler_min_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_min_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_min_p_clone, + /* .free = */ llama_sampler_min_p_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { @@ -1032,12 +1189,16 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_typical_i = { - /* .name = */ llama_sampler_typical_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_typical_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_typical_clone, - /* .free = */ llama_sampler_typical_free, + /* .name = */ llama_sampler_typical_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_typical_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_typical_clone, + /* .free = */ llama_sampler_typical_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { @@ -1076,12 +1237,16 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_temp_i = { - /* .name = */ llama_sampler_temp_name, - /* .accept = */ nullptr, - /* .apply = */ 
llama_sampler_temp_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_temp_clone, - /* .free = */ llama_sampler_temp_free, + /* .name = */ llama_sampler_temp_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_clone, + /* .free = */ llama_sampler_temp_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_temp(float temp) { @@ -1186,12 +1351,16 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_temp_ext_i = { - /* .name = */ llama_sampler_temp_ext_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_temp_ext_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_temp_ext_clone, - /* .free = */ llama_sampler_temp_ext_free, + /* .name = */ llama_sampler_temp_ext_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_ext_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_ext_clone, + /* .free = */ llama_sampler_temp_ext_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { @@ -1280,12 +1449,16 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_xtc_i = { - /* .name = */ llama_sampler_xtc_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sample_xtc_apply, - /* .reset = */ llama_sampler_xtc_reset, - /* .clone = */ llama_sampler_xtc_clone, - /* .free = */ llama_sampler_xtc_free, + /* .name = */ llama_sampler_xtc_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sample_xtc_apply, + /* .reset = */ llama_sampler_xtc_reset, + /* .clone = */ llama_sampler_xtc_clone, + /* .free = */ llama_sampler_xtc_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { @@ -1388,12 +1561,16 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_mirostat_i = { - /* .name = */ llama_sampler_mirostat_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_mirostat_apply, - /* .reset = */ llama_sampler_mirostat_reset, - /* .clone = */ llama_sampler_mirostat_clone, - /* .free = */ llama_sampler_mirostat_free, + /* .name = */ llama_sampler_mirostat_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_apply, + /* .reset = */ llama_sampler_mirostat_reset, + /* .clone = */ llama_sampler_mirostat_clone, + /* .free = */ llama_sampler_mirostat_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { @@ -1487,12 +1664,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_mirostat_v2_i = { - /* .name = */ llama_sampler_mirostat_v2_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_mirostat_v2_apply, - /* .reset = */ llama_sampler_mirostat_v2_reset, - /* .clone 
= */ llama_sampler_mirostat_v2_clone, - /* .free = */ llama_sampler_mirostat_v2_free, + /* .name = */ llama_sampler_mirostat_v2_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_v2_apply, + /* .reset = */ llama_sampler_mirostat_v2_reset, + /* .clone = */ llama_sampler_mirostat_v2_clone, + /* .free = */ llama_sampler_mirostat_v2_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { @@ -1604,12 +1785,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_grammar_i = { - /* .name = */ llama_sampler_grammar_name, - /* .accept = */ llama_sampler_grammar_accept_impl, - /* .apply = */ llama_sampler_grammar_apply, - /* .reset = */ llama_sampler_grammar_reset, - /* .clone = */ llama_sampler_grammar_clone, - /* .free = */ llama_sampler_grammar_free, + /* .name = */ llama_sampler_grammar_name, + /* .accept = */ llama_sampler_grammar_accept_impl, + /* .apply = */ llama_sampler_grammar_apply, + /* .reset = */ llama_sampler_grammar_reset, + /* .clone = */ llama_sampler_grammar_clone, + /* .free = */ llama_sampler_grammar_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; static struct llama_sampler * llama_sampler_init_grammar_impl( @@ -1811,12 +1996,16 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_penalties_i = { - /* .name = */ llama_sampler_penalties_name, - /* .accept = */ llama_sampler_penalties_accept, - /* .apply = */ llama_sampler_penalties_apply, - /* .reset = */ llama_sampler_penalties_reset, - /* .clone = */ llama_sampler_penalties_clone, - /* .free = */ llama_sampler_penalties_free, + /* .name = */ llama_sampler_penalties_name, + /* .accept = */ llama_sampler_penalties_accept, + /* .apply = */ llama_sampler_penalties_apply, + /* .reset = */ llama_sampler_penalties_reset, + /* .clone = */ llama_sampler_penalties_clone, + /* .free = */ llama_sampler_penalties_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_penalties( @@ -1902,12 +2091,16 @@ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_top_n_sigma_i = { - /* .name = */ llama_sampler_top_n_sigma_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_n_sigma_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_n_sigma_clone, - /* .free = */ llama_sampler_top_n_sigma_free, + /* .name = */ llama_sampler_top_n_sigma_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_n_sigma_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_n_sigma_clone, + /* .free = */ llama_sampler_top_n_sigma_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_n_sigma(float n) { @@ -2232,12 +2425,16 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_dry_i = { - /* .name = */ llama_sampler_dry_name, - /* .accept = */ llama_sampler_dry_accept, - /* .apply = */ 
llama_sampler_dry_apply, - /* .reset = */ llama_sampler_dry_reset, - /* .clone = */ llama_sampler_dry_clone, - /* .free = */ llama_sampler_dry_free, + /* .name = */ llama_sampler_dry_name, + /* .accept = */ llama_sampler_dry_accept, + /* .apply = */ llama_sampler_dry_apply, + /* .reset = */ llama_sampler_dry_reset, + /* .clone = */ llama_sampler_dry_clone, + /* .free = */ llama_sampler_dry_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { @@ -2373,12 +2570,16 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_logit_bias_i = { - /* .name = */ llama_sampler_logit_bias_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_logit_bias_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_logit_bias_clone, - /* .free = */ llama_sampler_logit_bias_free, + /* .name = */ llama_sampler_logit_bias_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_logit_bias_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_logit_bias_clone, + /* .free = */ llama_sampler_logit_bias_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_logit_bias( @@ -2603,12 +2804,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_infill_i = { - /* .name = */ llama_sampler_infill_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_infill_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_infill_clone, - /* .free = */ llama_sampler_infill_free, + /* .name = */ llama_sampler_infill_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_infill_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_infill_clone, + /* .free = */ llama_sampler_infill_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .set_backend_context = */ nullptr, }; struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) { diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 759dd7dcb7..d92311f58a 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -21,6 +21,9 @@ struct llama_sampler_chain { mutable int64_t t_sample_us; mutable int32_t n_sample; + + // simple version tracking for GPU sampling graph can_reuse + uint64_t version = 0; }; struct llama_sampler * llama_sampler_init_dry_testing( diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d9cc5e933f..0db8b4bd88 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -206,6 +206,18 @@ llama_build_and_test(test-backend-ops.cpp) llama_build_and_test(test-model-load-cancel.cpp LABEL "model") llama_build_and_test(test-autorelease.cpp LABEL "model") +llama_build_and_test(test-backend-sampler.cpp LABEL "model") +target_include_directories(test-backend-sampler PRIVATE ${PROJECT_SOURCE_DIR}/src) +llama_test(test-backend-sampler NAME test-backend-sampler-greedy ARGS --test greedy) +llama_test(test-backend-sampler NAME test-backend-sampler-temp ARGS --test temp) +llama_test(test-backend-sampler NAME 
test-backend-sampler-top_k ARGS --test top_k) +llama_test(test-backend-sampler NAME test-backend-sampler-dist ARGS --test dist) +llama_test(test-backend-sampler NAME test-backend-sampler-dist-and-cpu ARGS --test dist_and_cpu) +llama_test(test-backend-sampler NAME test-backend-sampler-logit-bias ARGS --test logit_bias) +llama_test(test-backend-sampler NAME test-backend-sampler-mul_seq ARGS --test multi_sequence) +llama_test(test-backend-sampler NAME test-backend-sampler-set-sampler ARGS --test set_sampler) + + if (NOT GGML_BACKEND_DL) # these tests use the backends directly and cannot be built with dynamic loading llama_build_and_test(test-barrier.cpp) diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp new file mode 100644 index 0000000000..191eebba3e --- /dev/null +++ b/tests/test-backend-sampler.cpp @@ -0,0 +1,760 @@ +#include "ggml.h" +#include "llama.h" +#include "get-model.h" +#include "common.h" + +#ifdef NDEBUG +#undef NDEBUG +#endif + +#include +#include +#include +#include +#include +#include +#include + +struct test_model_context { + llama_model * model = nullptr; + llama_context * ctx = nullptr; + const llama_vocab * vocab = nullptr; + int n_vocab = 0; + std::unordered_map seq_positions; + std::unordered_map last_batch_info; + + bool setup_model(const char * model_path) { + if (model != nullptr) { + return true; + } + + llama_backend_init(); + + llama_model_params mparams = llama_model_default_params(); + model = llama_model_load_from_file(model_path, mparams); + if (model == nullptr) { + fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path); + cleanup(); + return false; + } + vocab = llama_model_get_vocab(model); + + return true; + } + + bool setup(const char * model_path, std::vector & configs) { + if (model == nullptr) { + setup_model(model_path); + } + + if (model != nullptr && ctx != nullptr) { + return true; + } + + llama_context_params cparams = llama_context_default_params(); + cparams.n_ctx = 512; + cparams.n_batch = 512; + cparams.samplers = configs.data(); + cparams.n_samplers = configs.size(); + + int32_t max_seq_id = 0; + for (const auto & config : configs) { + if (config.seq_id > max_seq_id) { + max_seq_id = config.seq_id; + } + } + cparams.n_seq_max = max_seq_id + 1; + + ctx = llama_init_from_model(model, cparams); + if (ctx == nullptr) { + fprintf(stderr, "Warning: failed to create context, skipping test\n"); + cleanup(); + return false; + } + llama_set_warmup(ctx, false); + + vocab = llama_model_get_vocab(model); + n_vocab = llama_vocab_n_tokens(vocab); + fprintf(stderr, "Vocabulary size: %d\n", n_vocab); + + return true; + } + + bool decode(const std::map & prompts) { + if (ctx == nullptr || vocab == nullptr) { + fprintf(stderr, "Error: context not initialized, call setup() first\n"); + return false; + } + + last_batch_info.clear(); + llama_batch batch = llama_batch_init(512, 0, prompts.size()); + + int n_tokens_per_prompt = 0; + + for (const auto & [seq_id, prompt] : prompts) { + std::vector tokens; + tokens.push_back(llama_vocab_bos(vocab)); + + std::vector prompt_tokens(32); + int n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.length(), + prompt_tokens.data(), prompt_tokens.size(), + false, false); + //TODO: refactor this function to just handle a single prompt at a time + // to avoid this check and complexity. 
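+ // For now every prompt in the batch must tokenize to the same length; mixed-length prompts are rejected below.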
+ if (n_tokens_per_prompt == 0) { + n_tokens_per_prompt = n_tokens; + } else { + if (n_tokens != n_tokens_per_prompt) { + fprintf(stderr, "Error: prompts must have the same number of tokens\n"); + llama_batch_free(batch); + return false; + } + n_tokens_per_prompt = n_tokens; + } + if (n_tokens < 0) { + fprintf(stderr, "Warning: tokenization failed for seq_id %d\n", seq_id); + llama_batch_free(batch); + return false; + } + + for (int i = 0; i < n_tokens; i++) { + tokens.push_back(prompt_tokens[i]); + } + + for (size_t i = 0; i < tokens.size(); i++) { + common_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1); + } + + seq_positions[seq_id] = tokens.size(); + } + + + printf("Batch contents:\n"); + printf(" n_tokens: %d\n", batch.n_tokens); + for (int i = 0; i < batch.n_tokens; i++) { + printf(" token[%d]: tok=%-5d, pos=%d, n_seq_id=%d, seq_ids=[", i, batch.token[i], batch.pos[i], batch.n_seq_id[i]); + + for (int j = 0; j < batch.n_seq_id[i]; j++) { + printf("%d%s", batch.seq_id[i][j], j < batch.n_seq_id[i]-1 ? ", " : ""); + } + printf("], logits=%d\n", batch.logits[i]); +} + + if (llama_decode(ctx, batch) != 0) { + fprintf(stderr, "Warning: llama_decode failed\n"); + llama_batch_free(batch); + return false; + } + + // Build mapping from seq id to batch token idx + for (int i = 0; i < batch.n_tokens; i++) { + if (batch.logits[i]) { + llama_seq_id seq_id = batch.seq_id[i][0]; + last_batch_info[seq_id] = i; + printf("seq %d : batch idx %d\n", seq_id, i); + } + } + + llama_batch_free(batch); + return true; + } + + int32_t idx_for_seq(llama_seq_id seq_id) { + auto it = last_batch_info.find(seq_id); + if (it == last_batch_info.end()) { + fprintf(stderr, "Error: no batch index found for seq_id %d\n", seq_id); + return -1; + } + return it->second; + } + + bool decode_token(llama_token token, llama_seq_id seq_id = 0) { + if (ctx == nullptr) { + fprintf(stderr, "Error: context not initialized, call setup() first\n"); + return false; + } + + llama_batch batch = llama_batch_init(1, 0, 1); + int32_t pos = seq_positions[seq_id]; + common_batch_add(batch, token, pos, { seq_id }, true); + + if (llama_decode(ctx, batch) != 0) { + fprintf(stderr, "Warning: llama_decode failed for token %d in seq %d\n", token, seq_id); + llama_batch_free(batch); + return false; + } + + last_batch_info.clear(); + for (int i = 0; i < batch.n_tokens; i++) { + if (batch.logits[i]) { + llama_seq_id cur_seq = batch.seq_id[i][0]; + last_batch_info[cur_seq] = i; + } + } + + seq_positions[seq_id]++; + llama_batch_free(batch); + return true; + } + + bool decode_tokens(const std::map & seq_tokens) { + if (ctx == nullptr) { + fprintf(stderr, "Error: context not initialized, call setup() first\n"); + return false; + } + + llama_batch batch = llama_batch_init(seq_tokens.size(), 0, seq_tokens.size()); + + for (const auto & [seq_id, token] : seq_tokens) { + int32_t pos = seq_positions[seq_id]; + common_batch_add(batch, token, pos, { seq_id }, true); + } + + if (llama_decode(ctx, batch) != 0) { + fprintf(stderr, "Warning: llama_decode failed for batch tokens\n"); + llama_batch_free(batch); + return false; + } + + for (const auto & [seq_id, _] : seq_tokens) { + seq_positions[seq_id]++; + } + + last_batch_info.clear(); + for (int i = 0; i < batch.n_tokens; i++) { + if (batch.logits[i]) { + llama_seq_id cur_seq = batch.seq_id[i][0]; + last_batch_info[cur_seq] = i; + } + } + + llama_batch_free(batch); + return true; + } + + std::string token_to_piece(llama_token token, bool special) { + std::string piece; + 
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' + const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); + if (n_chars < 0) { + piece.resize(-n_chars); + int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special); + GGML_ASSERT(check == -n_chars); + } + else { + piece.resize(n_chars); + } + + return piece; + } + + void cleanup() { + if (ctx) llama_free(ctx); + if (model) llama_model_free(model); + llama_backend_free(); + ctx = nullptr; + model = nullptr; + vocab = nullptr; + } + + ~test_model_context() { + cleanup(); + } +}; + +static void test_backend_greedy_sampling(const char * model_path) { + test_model_context test_ctx; + + const int seq_id = 0; + + struct llama_sampler_chain_params backend_sampler_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_sampler_params); + + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_greedy()); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }}; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Some"}})) { + return; + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx); + printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + token = llama_get_backend_sampled_token_ith(test_ctx.ctx, -1); + printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + for (int i = 0; i < 10; i++) { + int32_t loop_idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, loop_idx); + printf("Generation step %d: token id:%d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str()); + test_ctx.decode_token(token, 0); + } +} + +static void test_backend_top_k_sampling(const char * model_path) { + test_model_context test_ctx; + + const int seq_id = 0; + const int32_t k = 8; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_top_k(k)); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }}; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + return; + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + float * logits = llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx); + uint32_t n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx); + for (size_t i = 0; i < n_logits; ++i) { + printf("top_k logit[%zu] = %.6f\n", i, logits[i]); + } + + // Sample using CPU sampler for verification that it is possible to do hybrid + // sampling, first top_k on the backend and then dist on the CPU. 
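+ // Note: llama_sampler_sample reads the backend-filtered logits (and their token ids) for this index, so the CPU dist sampler only considers the candidates kept by the backend top_k sampler.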
+ struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * chain = llama_sampler_chain_init(chain_params); + GGML_ASSERT(chain->iface->apply_ggml != nullptr); + + llama_sampler_chain_add(chain, llama_sampler_init_dist(18)); + llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + printf("backend top-k hybrid sampling test PASSED\n"); + + llama_sampler_free(chain); +} + +static void test_backend_temp_sampling(const char * model_path) { + test_model_context test_ctx; + + const float temp_0 = 0.8f; + struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain_0 = llama_sampler_chain_init(backend_chain_params_0); + llama_sampler_chain_add(backend_sampler_chain_0, llama_sampler_backend_init_temp(temp_0)); + + const float temp_1 = 0.1f; + struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain_1 = llama_sampler_chain_init(backend_chain_params_1); + llama_sampler_chain_add(backend_sampler_chain_1, llama_sampler_backend_init_temp(temp_1)); + + std::vector backend_sampler_configs = { + { 0, backend_sampler_chain_0 }, + { 1, backend_sampler_chain_1 } + }; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{0, "Some where over"}, {1, "Once upon a"}})) { + return; + } + + int32_t batch_idx_0 = test_ctx.idx_for_seq(0); + int32_t batch_idx_1 = test_ctx.idx_for_seq(1); + + // Sample from sequence 0 using CPU sampler + struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params(); + struct llama_sampler * chain_0 = llama_sampler_chain_init(chain_params_0); + llama_sampler_chain_add(chain_0, llama_sampler_init_dist(18)); + + llama_token token_0 = llama_sampler_sample(chain_0, test_ctx.ctx, batch_idx_0); + const std::string token_0_str = test_ctx.token_to_piece(token_0, false); + printf("Sequence 0 sampled token id:%d, string: '%s'\n", token_0, token_0_str.c_str()); + GGML_ASSERT(token_0 >= 0 && token_0 < test_ctx.n_vocab); + + // Sample from sequence 1 using CPU sampler + struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params(); + struct llama_sampler * chain_1 = llama_sampler_chain_init(chain_params_1); + llama_sampler_chain_add(chain_1, llama_sampler_init_dist(18)); + + llama_token token_1 = llama_sampler_sample(chain_1, test_ctx.ctx, batch_idx_1); + const std::string token_1_str = test_ctx.token_to_piece(token_1, false); + printf("Sequence 1 sampled token id:%d, string: '%s'\n", token_1, token_1_str.c_str()); + GGML_ASSERT(token_1 >= 0 && token_1 < test_ctx.n_vocab); + + printf("backend temp sampling test PASSED\n"); + + llama_sampler_free(chain_0); + llama_sampler_free(chain_1); +} + +static void test_backend_multi_sequence_sampling(const char * model_path) { + test_model_context test_ctx; + + struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params(); + struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0); + llama_sampler_chain_add(sampler_chain_0, llama_sampler_backend_init_greedy()); + + struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params(); + struct llama_sampler * sampler_chain_1 = llama_sampler_chain_init(chain_params_1); + 
llama_sampler_chain_add(sampler_chain_1, llama_sampler_backend_init_temp(0.8f)); + llama_sampler_chain_add(sampler_chain_1, llama_sampler_backend_init_greedy()); + + std::vector backend_sampler_configs = { + { 0, sampler_chain_0 }, + { 1, sampler_chain_1 } + }; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + std::map prompts = { + {0, "Hello"}, + {1, "Some"} + }; + + if (!test_ctx.decode(prompts)) { + return; + } + + int32_t batch_idx_0 = test_ctx.idx_for_seq(0); + llama_token seq0_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx_0); + const std::string seq0_token_str = test_ctx.token_to_piece(seq0_token, false); + printf("Seq 0 sampled token id=%d, string='%s'\n", seq0_token, seq0_token_str.c_str()); + GGML_ASSERT(seq0_token >= 0 && seq0_token < test_ctx.n_vocab); + + int32_t batch_idx_1 = test_ctx.idx_for_seq(1); + llama_token seq1_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx_1); + const std::string seq1_token_str = test_ctx.token_to_piece(seq1_token, false); + printf("Seq 1 sampled token id=%d, string='%s'\n", seq1_token, seq1_token_str.c_str()); + GGML_ASSERT(seq1_token >= 0 && seq1_token < test_ctx.n_vocab); + + // Generate tokens for each sequence + printf("\nMulti-sequence generation:\n"); + for (int step = 0; step < 4; step++) { + std::map tokens; + + for (llama_seq_id seq_id : {0, 1}) { + int32_t idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf(" Seq %d, step %d: token id=%d, string='%s'\n", seq_id, step, token, token_str.c_str()); + tokens[seq_id] = token; + } + + // Decode all tokens in a single batch + if (!test_ctx.decode_tokens(tokens)) { + break; + } + } + + printf("backend multi-sequence sampling test PASSED\n"); +} + +static void test_backend_dist_sampling(const char * model_path) { + test_model_context test_ctx; + + const int32_t seed = 88; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed)); + std::vector backend_sampler_configs = {{ 0, backend_sampler_chain }}; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{0, "Hello"}})) { + return; + } + + llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(0)); + printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + token = llama_get_backend_sampled_token_ith(test_ctx.ctx, -1); + printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); +} + +static void test_backend_dist_sampling_and_cpu(const char * model_path) { + test_model_context test_ctx; + + const int seq_id = 0; + const int32_t seed = 88; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed)); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }}; + + if (!test_ctx.setup(model_path, 
backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + return; + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + // Sample using CPU sampler + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * chain = llama_sampler_chain_init(chain_params); + llama_sampler_chain_add(chain, llama_sampler_init_dist(18)); + + llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx); + llama_token cpu_token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx); + GGML_ASSERT(backend_token == cpu_token); +} + +static void test_backend_logit_bias_sampling(const char * model_path) { + test_model_context test_ctx; + + // Calling setup_model to ensure vocab is loaded and can be accessed + if (!test_ctx.setup_model(model_path)) { + return; + } + + const int seq_id = 0; + + // Create the logit biases vector. + std::vector logit_bias; + + // Get the token for the piece "World". + const std::string piece = "World"; + std::vector tokens(16); + llama_tokenize(test_ctx.vocab, piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false); + llama_token bias_token = tokens[0]; + logit_bias.push_back({ bias_token, +100.0f }); + printf("biasing token piece '%s' -> token id %d\n", piece.c_str(), bias_token); + + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_logit_bias( + llama_vocab_n_tokens(test_ctx.vocab), + logit_bias.size(), + logit_bias.data())); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(88)); + + std::vector backend_sampler_configs = { + { seq_id, backend_sampler_chain }, + }; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + return; + } + + llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id)); + const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false); + printf("logit bias sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str()); + GGML_ASSERT(backend_token == bias_token); +} + +static void test_backend_set_sampler(const char * model_path) { + test_model_context test_ctx; + + const int32_t seed = 88; + const int seq_id = 0; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params); + llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed)); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }}; + + if (!test_ctx.setup(model_path, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + return; + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + // Sample using backend sampler configured above + llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx); + const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false); + printf("dist sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str()); + + // Now clear the backend sampler for this sequence. 
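+ // Passing nullptr removes the backend sampler for this seq_id; after the next decode there is no backend-sampled token (LLAMA_TOKEN_NULL) and no probs, which is asserted below.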
+ llama_set_backend_sampler(test_ctx.ctx, seq_id, nullptr); + printf("Cleared backend sampler for seq_id %d\n", seq_id); + + // Sample using CPU sampler + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * chain = llama_sampler_chain_init(chain_params); + llama_sampler_chain_add(chain, llama_sampler_init_dist(18)); + + std::map tokens = { { seq_id, backend_token}, }; + if (!test_ctx.decode_tokens(tokens)) { + return; + } + + // Should not have any sampled token or probs after clearing the backend sampler. + const int32_t idx = test_ctx.idx_for_seq(seq_id); + GGML_ASSERT(llama_get_backend_sampled_token_ith(test_ctx.ctx, idx) == LLAMA_TOKEN_NULL); + GGML_ASSERT(llama_get_backend_sampled_probs_ith(test_ctx.ctx, idx) == nullptr); + + // Sample the token using the CPU sampler chain. + llama_token token2 = llama_sampler_sample(chain, test_ctx.ctx, seq_id); + const std::string token2_str = test_ctx.token_to_piece(token2, false); + printf("CPU sampled token after clearing backend sampler: id=%d, string='%s'\n", token2, token2_str.c_str()); + std::map tokens2 = { { seq_id, token2}, }; + + // Set a new backend sampler for the sequence. + struct llama_sampler_chain_params new_backend_chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * new_backend_sampler_chain = llama_sampler_chain_init(new_backend_chain_params); + llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_backend_init_top_k(20)); + llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_backend_init_dist(seed)); + llama_set_backend_sampler(test_ctx.ctx, seq_id, new_backend_sampler_chain); + + if (!test_ctx.decode_tokens(tokens2)) { + return; + } + + llama_token new_backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id)); + const std::string new_backend_token_str = test_ctx.token_to_piece(new_backend_token, false); + printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str()); +} + +struct backend_test_case { + const char * name; + void (*fn)(const char *); + bool enabled_by_default; +}; + +static const backend_test_case BACKEND_TESTS[] = { + { "greedy", test_backend_greedy_sampling, true }, + { "logit_bias", test_backend_logit_bias_sampling, true }, + { "temp", test_backend_temp_sampling, true }, + { "top_k", test_backend_top_k_sampling, false }, + { "multi_sequence", test_backend_multi_sequence_sampling, false }, + { "dist", test_backend_dist_sampling, false }, + { "dist_and_cpu", test_backend_dist_sampling_and_cpu, false }, + { "set_sampler", test_backend_set_sampler, true }, +}; + +struct backend_cli_args { + const char * model = nullptr; + const char * test = nullptr; +}; + +static backend_cli_args parse_backend_cli(int argc, char ** argv) { + backend_cli_args out; + + for (int i = 1; i < argc; ++i) { + const char * arg = argv[i]; + + if (std::strcmp(arg, "--test") == 0) { + if (i + 1 >= argc) { + fprintf(stderr, "--test expects a value\n"); + exit(EXIT_FAILURE); + } + out.test = argv[++i]; + continue; + } + if (std::strncmp(arg, "--test=", 7) == 0) { + out.test = arg + 7; + continue; + } + if (std::strcmp(arg, "--model") == 0) { + if (i + 1 >= argc) { + fprintf(stderr, "--model expects a value\n"); + exit(EXIT_FAILURE); + } + out.model = argv[++i]; + continue; + } + if (std::strncmp(arg, "--model=", 8) == 0) { + out.model = arg + 8; + continue; + } + if (!out.model) { + out.model = arg; + continue; + } + + fprintf(stderr, "Unexpected argument: %s\n", 
arg); + exit(EXIT_FAILURE); + } + + return out; +} + +static std::vector collect_tests_to_run(const char * requested) { + std::vector selected; + + if (requested != nullptr) { + for (const auto & test : BACKEND_TESTS) { + if (std::strcmp(test.name, requested) == 0) { + selected.push_back(&test); + break; + } + } + if (selected.empty()) { + fprintf(stderr, "Unknown test '%s'. Available tests:\n", requested); + for (const auto & test : BACKEND_TESTS) { + fprintf(stderr, " %s\n", test.name); + } + exit(EXIT_FAILURE); + } + } else { + for (const auto & test : BACKEND_TESTS) { + if (test.enabled_by_default) { + selected.push_back(&test); + } + } + } + + if (selected.empty()) { + fprintf(stderr, "No backend sampling tests selected. Use --test= to pick one.\n"); + } + + return selected; +} + +static void run_tests(const std::vector & tests, const char * model_path) { + for (const auto * test : tests) { + fprintf(stderr, "\n=== %s ===\n", test->name); + test->fn(model_path); + } +} + + +int main(int argc, char *argv[] ) { + const backend_cli_args args = parse_backend_cli(argc, argv); + + std::array model_argv { argv[0], const_cast(args.model) }; + const int model_argc = args.model ? 2 : 1; + char * model_path = get_model_or_exit(model_argc, model_argv.data()); + + auto * file = fopen(model_path, "r"); + if (file == nullptr) { + fprintf(stderr, "no model at '%s' found\n", model_path); + return EXIT_FAILURE; + } + + fprintf(stderr, "using '%s'\n", model_path); + fclose(file); + + ggml_time_init(); + + const std::vector tests = collect_tests_to_run(args.test); + if (!tests.empty()) { + run_tests(tests, model_path); + } + + return 0; +}
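For reference, below is a minimal usage sketch of the backend-sampling API introduced by this change, assembled from the calls exercised in test-backend-sampler.cpp above. It is illustrative only: the helper name backend_sample_once, the top-k value, the RNG seed, and the field order assumed for llama_sampler_seq_config are not taken from the patch, and error handling and tokenization are omitted.

#include "llama.h"
#include "common.h" // for common_batch_add

#include <vector>

// Decode a pre-tokenized prompt for sequence 0 and return the token sampled on the backend.
static llama_token backend_sample_once(llama_model * model, const std::vector<llama_token> & tokens) {
    // backend sampler chain: top-k filtering followed by final (dist) sampling
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(40));  // k chosen arbitrarily
    llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(1234)); // arbitrary seed

    // attach the chain to sequence 0 via the context params
    llama_sampler_seq_config cfg = { /*seq_id*/ 0, chain }; // field order assumed from the tests' aggregate init
    llama_context_params cparams = llama_context_default_params();
    cparams.samplers   = &cfg;
    cparams.n_samplers = 1;

    llama_context * ctx = llama_init_from_model(model, cparams);

    // decode the prompt; only the last token requests an output (and hence sampling)
    llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
    for (size_t i = 0; i < tokens.size(); i++) {
        common_batch_add(batch, tokens[i], i, { 0 }, i == tokens.size() - 1);
    }
    llama_decode(ctx, batch);

    // the token was sampled on the backend as part of the decode graph
    const llama_token tok = llama_get_backend_sampled_token_ith(ctx, -1);

    llama_batch_free(batch);
    llama_free(ctx); // as in the tests above, the backend chain is not freed separately here
    return tok;
}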