diff --git a/common/sampling.cpp b/common/sampling.cpp index 94367bd307..ca3a0f0691 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -347,19 +347,11 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st is_backend = is_backend && (params.samplers.size() == 0 || common_sampler_type_has_backend_support(params.samplers[idx_smpl])); if (params.has_logit_bias()) { - if (is_backend) { - llama_sampler_chain_add(result->chain_backend, - llama_sampler_backend_init_logit_bias( - llama_vocab_n_tokens(vocab), - params.logit_bias.size(), - params.logit_bias.data())); - } else { - llama_sampler_chain_add(result->chain, - llama_sampler_init_logit_bias( - llama_vocab_n_tokens(vocab), - params.logit_bias.size(), - params.logit_bias.data())); - } + llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, + llama_sampler_init_logit_bias( + llama_vocab_n_tokens(vocab), + params.logit_bias.size(), + params.logit_bias.data())); } if (params.mirostat == 0) { @@ -375,16 +367,16 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st switch (cnstr) { case COMMON_SAMPLER_TYPE_TOP_K: - llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_top_k(params.top_k)); + llama_sampler_chain_add(result->chain_backend, llama_sampler_init_top_k(params.top_k)); break; case COMMON_SAMPLER_TYPE_TEMPERATURE: - llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_temp(params.temp)); + llama_sampler_chain_add(result->chain_backend, llama_sampler_init_temp(params.temp)); break; case COMMON_SAMPLER_TYPE_MIN_P: - llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_min_p(params.min_p)); + llama_sampler_chain_add(result->chain_backend, llama_sampler_init_min_p(params.min_p, params.min_keep)); break; case COMMON_SAMPLER_TYPE_TOP_P: - llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_top_p(params.top_p)); + llama_sampler_chain_add(result->chain_backend, llama_sampler_init_top_p(params.top_p, params.min_keep)); break; default: GGML_ASSERT(false && "unsupported backend sampler"); @@ -439,11 +431,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st } } - if (is_backend) { - llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_dist(params.seed)); - } else { - llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); - } + llama_sampler_chain_add(is_backend ? 
result->chain_backend : result->chain, llama_sampler_init_dist(params.seed)); } else if (params.mirostat == 1) { llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index e23a3bab21..97cad5d260 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -74,9 +74,9 @@ int main(int argc, char ** argv) { llama_sampler * smpl = llama_sampler_chain_init(sparams); if (params.sampling.backend_sampling) { - llama_sampler_chain_add(smpl, llama_sampler_backend_init_top_k(params.sampling.top_k)); - llama_sampler_chain_add(smpl, llama_sampler_backend_init_temp (params.sampling.temp)); - llama_sampler_chain_add(smpl, llama_sampler_backend_init_dist (params.sampling.seed)); + llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k)); + llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp)); + llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed)); } else { llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k)); llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep)); diff --git a/include/llama.h b/include/llama.h index 38178d919f..263733cf2c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1388,31 +1388,9 @@ extern "C" { // Backend samplers // - /// @details Greedy sampling on backend - always selects the token with the highest probability - LLAMA_API struct llama_sampler * llama_sampler_backend_init_greedy(void); - - /// @details Temperature scaling on backend - scales logits by 1/temperature - LLAMA_API struct llama_sampler * llama_sampler_backend_init_temp(float temp); - - /// @details Top-K filtering on backend - keeps only the k tokens with highest probabilities - LLAMA_API struct llama_sampler * llama_sampler_backend_init_top_k(int32_t k); - - /// @details Distribution sampling on backend - final sampling step that selects a token - LLAMA_API struct llama_sampler * llama_sampler_backend_init_dist(uint32_t seed); - - /// @details Min-P filtering on backend - filter tokens with a probability less than p times the maximum probability. - LLAMA_API struct llama_sampler * llama_sampler_backend_init_min_p(float p); - - /// @details Top-p filtering on backend - filter all tokens with cumulative pseudo-probability less than p. 
- LLAMA_API struct llama_sampler * llama_sampler_backend_init_top_p(float p); - // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl); - LLAMA_API struct llama_sampler * llama_sampler_backend_init_logit_bias(int32_t n_vocab, - int32_t n_logit_bias, - const llama_logit_bias * logit_bias); - /// @details Sample and accept a token from the idx-th output of the last evaluation // // Shorthand for: diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 97320fe97d..f7a8c9841e 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -31,7 +31,6 @@ add_library(llama llama-model.cpp llama-quant.cpp llama-sampling.cpp - llama-backend-sampler.cpp llama-vocab.cpp unicode-data.cpp unicode.cpp diff --git a/src/llama-backend-sampler.cpp b/src/llama-backend-sampler.cpp deleted file mode 100644 index 20a52866db..0000000000 --- a/src/llama-backend-sampler.cpp +++ /dev/null @@ -1,711 +0,0 @@ -#include "llama.h" -#include "ggml.h" -#include <chrono> -#include <cmath> -#include <cstdio> -#include <random> -#include <vector> - -static void llama_sampler_backend_greedy_apply_ggml( - struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data) { - GGML_UNUSED(gf); - GGML_UNUSED(smpl); - struct ggml_tensor * argmax_result = ggml_argmax(ctx, ggml_data->logits); - ggml_set_name(argmax_result, "argmax_result"); - ggml_data->sampled = argmax_result; -} - -static const char * llama_sampler_backend_greedy_sampler_name(const struct llama_sampler *) { - return "test-ggml"; -} - -static struct llama_sampler * llama_sampler_backend_greedy_clone(const struct llama_sampler * smpl) { - (void) smpl; - return llama_sampler_backend_init_greedy(); -} - -struct llama_sampler * llama_sampler_backend_init_greedy() { - static const llama_sampler_i iface = { - /*.name =*/ llama_sampler_backend_greedy_sampler_name, - /*.accept =*/ nullptr, - /*.apply =*/ nullptr, - /*.reset =*/ nullptr, - /*.clone =*/ llama_sampler_backend_greedy_clone, - /*.free =*/ nullptr, - /*.apply_ggml =*/ llama_sampler_backend_greedy_apply_ggml, - /*.accept_ggml =*/ nullptr, - /*.set_input_ggml =*/ nullptr, - /*.init_ggml =*/ nullptr, - }; - - auto * sampler = new llama_sampler { - /*.iface =*/ &iface, - /*.ctx =*/ nullptr, - }; - - return sampler; -} - -struct llama_sampler_backend_temp_ctx { - float temp; -}; - - -static void llama_sampler_backend_temp_apply_ggml( - struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data) { - - auto * ctx_data = (llama_sampler_backend_temp_ctx *) smpl->ctx; - - if (ctx_data->temp <= 0.0f) { - return; - } - - struct ggml_tensor * scaled = ggml_scale(ctx, ggml_data->logits, 1.0f / ctx_data->temp); - ggml_set_name(scaled, "temp_scaled"); - - // Make sure the scaled tensor is contiguous for subsequent operations - ggml_data->logits = ggml_cont(ctx, scaled); - ggml_set_name(ggml_data->logits, "temp_scaled_logits"); - - ggml_build_forward_expand(gf, ggml_data->logits); -} - -static const char * llama_sampler_backend_temp_name(const struct llama_sampler *) { - return "backend-temp"; -} - -static void llama_sampler_backend_temp_free(struct llama_sampler * smpl) { - auto * ctx_data = (llama_sampler_backend_temp_ctx *) smpl->ctx; - delete ctx_data; -} - -static struct llama_sampler * llama_sampler_backend_temp_clone(const struct llama_sampler * smpl) { - auto * ctx = (llama_sampler_backend_temp_ctx *) smpl->ctx; - return
llama_sampler_backend_init_temp(ctx->temp); -} - -struct llama_sampler * llama_sampler_backend_init_temp(float temp) { - static const llama_sampler_i iface = { - /*.name =*/ llama_sampler_backend_temp_name, - /*.accept =*/ nullptr, - /*.apply =*/ nullptr, - /*.reset =*/ nullptr, - /*.clone =*/ llama_sampler_backend_temp_clone, - /*.free =*/ llama_sampler_backend_temp_free, - /*.apply_ggml =*/ llama_sampler_backend_temp_apply_ggml, - /*.accept_ggml =*/ nullptr, - /*.set_input_ggml =*/ nullptr, - /*.set_backend_context =*/ nullptr, - }; - - auto * ctx_data = new llama_sampler_backend_temp_ctx { - /*.temp =*/ temp, - }; - - auto * sampler = new llama_sampler { - /*.iface =*/ &iface, - /*.ctx =*/ ctx_data, - }; - - return sampler; -} - - -struct llama_sampler_backend_top_k_ctx { - int32_t k; - - // Only required for checking operation support and can be removed later. - ggml_backend_dev_t device; -}; - -static void llama_sampler_backend_top_k_init_ggml( - struct llama_sampler * smpl, - ggml_backend_buffer_type_t buft) { - auto * ctx_data = (llama_sampler_backend_top_k_ctx *) smpl->ctx; - ctx_data->device = ggml_backend_buft_get_device(buft); -} - -static void llama_sampler_backend_top_k_apply_ggml( - struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data) { - - auto * ctx_data = (llama_sampler_backend_top_k_ctx *) smpl->ctx; - - struct ggml_tensor * top_k = ggml_top_k(ctx, ggml_data->logits, ctx_data->k); - ggml_set_name(top_k, "top_k"); - - // top_k is a view of argsort - check if backend supports the underlying argsort operation - // by checking the source tensor (which is the argsort result) - if (ctx_data->device && top_k->src[0] && !ggml_backend_dev_supports_op(ctx_data->device, top_k->src[0])) { - fprintf(stderr, "Warning: backend does not support argsort operation required for top-k sampling\n"); - fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n"); - } - - ggml_data->candidates = top_k; - - struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]); - struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k); - ggml_set_name(top_k_rows, "top_k_rows"); - - ggml_data->logits = ggml_reshape_1d(ctx, top_k_rows, ctx_data->k); - ggml_build_forward_expand(gf, ggml_data->logits); -} - -static const char * llama_sampler_backend_top_k_name(const struct llama_sampler *) { - return "backend-top-k"; -} - -static void llama_sampler_backend_top_k_free(struct llama_sampler * smpl) { - auto * ctx_data = (llama_sampler_backend_top_k_ctx *) smpl->ctx; - delete ctx_data; -} - -static struct llama_sampler * llama_sampler_backend_top_k_clone(const struct llama_sampler * smpl) { - auto * ctx = (llama_sampler_backend_top_k_ctx *) smpl->ctx; - return llama_sampler_backend_init_top_k(ctx->k); -} - -struct llama_sampler * llama_sampler_backend_init_top_k(int32_t k) { - static const llama_sampler_i iface = { - /*.name =*/ llama_sampler_backend_top_k_name, - /*.accept =*/ nullptr, - /*.apply =*/ nullptr, - /*.reset =*/ nullptr, - /*.clone =*/ llama_sampler_backend_top_k_clone, - /*.free =*/ llama_sampler_backend_top_k_free, - /*.apply_ggml =*/ llama_sampler_backend_top_k_apply_ggml, - /*.accept_ggml =*/ nullptr, - /*.set_input_ggml =*/ nullptr, - /*.init_ggml =*/ llama_sampler_backend_top_k_init_ggml, - }; - - auto * ctx_data = new llama_sampler_backend_top_k_ctx { - /*.k =*/ k, - /*.device =*/ nullptr, - }; - - auto * sampler 
= new llama_sampler { - /*.iface =*/ &iface, - /*.ctx =*/ ctx_data, - }; - - return sampler; -} - - -static uint32_t get_rng_seed(uint32_t seed) { - if (seed == LLAMA_DEFAULT_SEED) { - // use system clock if std::random_device is not a true RNG - static bool is_rd_prng = std::random_device().entropy() == 0; - if (is_rd_prng) { - return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count(); - } - std::random_device rd; - return rd(); - } - return seed; -} - -struct llama_sampler_backend_dist_ctx { - const uint32_t seed; - uint32_t seed_cur; - std::mt19937 rng; - - struct ggml_tensor * uniform; - struct ggml_context * ctx; - ggml_backend_buffer_t buffer; - - // Only required for checking operation support and can be removed later. - ggml_backend_dev_t device; -}; - -static void llama_sampler_backend_dist_init_ggml( - struct llama_sampler * smpl, - ggml_backend_buffer_type_t buft) { - - auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx; - sctx->device = ggml_backend_buft_get_device(buft); - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - sctx->ctx = ggml_init(params); - - // Create the uniform random scalar input tensor. This will be set by - // llama_sampler_backend_dist_set_input_ggml after this graph is built. - sctx->uniform = ggml_new_tensor_1d(sctx->ctx, GGML_TYPE_F32, 1); - ggml_set_name(sctx->uniform, "uniform"); - ggml_set_input(sctx->uniform); - ggml_set_output(sctx->uniform); - - // Allocate all tensors from our context to the backend - sctx->buffer = ggml_backend_alloc_ctx_tensors_from_buft(sctx->ctx, buft); -} - -static void llama_sampler_backend_dist_set_input_ggml(struct llama_sampler * smpl) { - auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx; - GGML_ASSERT(sctx->uniform != nullptr); - - std::uniform_real_distribution<float> dist(0.0f, 1.0f); - const float rnd = dist(sctx->rng); - ggml_backend_tensor_set(sctx->uniform, &rnd, 0, sizeof(float)); -} - -static void llama_sampler_backend_dist_apply_ggml( - struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data) { - GGML_UNUSED(gf); - auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx; - - struct ggml_tensor * probs = ggml_soft_max(ctx, ggml_data->logits); - ggml_set_name(probs, "dist_probs"); - - struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs); - if (sctx->device && !ggml_backend_dev_supports_op(sctx->device, cumsum)) { - fprintf(stderr, "Warning: backend does not support cumsum operation required for dist sampling\n"); - fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n"); - } - ggml_set_name(cumsum, "cumsum"); - - // The uniform tensor has a random value and we subtract this tensor with - // the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub). - // Recall that each entry in cumsum is the cumulative probability up to that - // index so values stay negative while the cumulative total is below the - // random value, and become zero/positive once the threshold is crossed. - struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->uniform); - ggml_set_name(diff, "dist_cumsum"); - - // The ggml_step function produces a tensor where entries are 1 if the - // corresponding entry in diff is > 0, and 0 otherwise. So all values up to - // the index where the cumulative probability exceeds the random value are 0, - // and all entries after that are 1.
- struct ggml_tensor * mask = ggml_step(ctx, diff); - ggml_set_name(mask, "dist_mask"); - - // Taking the sum of the mask gives us the sum of elements after the threshold - // we are interested in. - struct ggml_tensor * idxf = ggml_sum(ctx, mask); - ggml_set_name(idxf, "dist_index_f32"); - - // Use ggml_scale_bias to scale the index value by -1 and then add the size - // of the mask to that value so we get the correct index ((-1 * idxf) + n). - struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32); - ggml_set_name(idx, "dist_index_i32"); - - // Map back to original vocab ids if a candidates tensor is available. - struct ggml_tensor * sampled_token = idx; - if (ggml_data->candidates != nullptr) { - struct ggml_tensor * candidates = ggml_data->candidates; - struct ggml_tensor * candidates_reshaped = ggml_view_2d(ctx, candidates, 1, ggml_nelements(candidates), - ggml_type_size(candidates->type), 0); - - sampled_token = ggml_get_rows(ctx, candidates_reshaped, idx); - ggml_set_name(sampled_token, "dist_sampled_token"); - } - - ggml_set_output(sampled_token); - ggml_data->sampled = sampled_token; -} - -static const char * llama_sampler_backend_dist_name(const struct llama_sampler *) { - return "backend-dist"; -} - -static void llama_sampler_backend_dist_free(struct llama_sampler * smpl) { - auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx; - ggml_backend_buffer_free(sctx->buffer); - ggml_free(sctx->ctx); - delete sctx; -} - -static struct llama_sampler * llama_sampler_backend_dist_clone(const struct llama_sampler * smpl) { - auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx; - return llama_sampler_backend_init_dist(sctx->seed); -} - - -struct llama_sampler * llama_sampler_backend_init_dist(uint32_t seed) { - static const llama_sampler_i iface = { - /*.name =*/ llama_sampler_backend_dist_name, - /*.accept =*/ nullptr, - /*.apply =*/ nullptr, - /*.reset =*/ nullptr, - /*.clone =*/ llama_sampler_backend_dist_clone, - /*.free =*/ llama_sampler_backend_dist_free, - /*.apply_ggml =*/ llama_sampler_backend_dist_apply_ggml, - /*.accept_ggml =*/ nullptr, - /*.set_input_ggml =*/ llama_sampler_backend_dist_set_input_ggml, - /*.init_ggml =*/ llama_sampler_backend_dist_init_ggml, - }; - - auto seed_cur = get_rng_seed(seed); - auto * ctx_data = new llama_sampler_backend_dist_ctx { - /*.seed =*/ seed, - /*.seed_cur =*/ seed_cur, - /*.rng =*/ std::mt19937(seed_cur), - /*.uniform =*/ nullptr, - /*.ctx =*/ nullptr, - /*.buffer =*/ nullptr, - /*.device =*/ nullptr, - }; - - auto * sampler = new llama_sampler { - /*.iface =*/ &iface, - /*.ctx =*/ ctx_data, - }; - - return sampler; -} - -struct llama_sampler_backend_logit_bias_ctx { - const int32_t n_vocab; - - const std::vector<llama_logit_bias> logit_bias; - - struct ggml_tensor * logit_bias_t; - struct ggml_context * ctx; - ggml_backend_buffer_t buffer; -}; - -static void llama_sampler_backend_logit_bias_init_ggml( - struct llama_sampler * smpl, - ggml_backend_buffer_type_t buft) { - auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx; - if (sctx->logit_bias.empty()) { - return; - } - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead() * sctx->n_vocab * sizeof(float), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - sctx->ctx = ggml_init(params); - - struct ggml_tensor * logit_bias = ggml_new_tensor_1d(sctx->ctx, GGML_TYPE_F32, sctx->n_vocab); - sctx->logit_bias_t = logit_bias; - ggml_set_name(sctx->logit_bias_t, "logit_bias"); - ggml_set_input(sctx->logit_bias_t); -
ggml_set_output(sctx->logit_bias_t); - - // Allocate all tensors from our context to the backend - sctx->buffer = ggml_backend_alloc_ctx_tensors_from_buft(sctx->ctx, buft); -} - -static void llama_sampler_backend_logit_bias_set_input_ggml(struct llama_sampler * smpl) { - auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx; - if (sctx->logit_bias.empty()) { - return; - } - GGML_ASSERT(sctx->logit_bias_t != nullptr); - - // Create a sparse logit_bias vector from the logit_bias entries. - std::vector<float> logit_bias_sparse(sctx->n_vocab, 0.0f); - for (const auto & lb : sctx->logit_bias) { - GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab); - logit_bias_sparse[lb.token] = lb.bias; - } - - ggml_backend_tensor_set(sctx->logit_bias_t, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->logit_bias_t)); -} - -static void llama_sampler_backend_logit_bias_apply_ggml( - struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data) { - GGML_UNUSED(gf); - GGML_UNUSED(ctx); - - auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx; - if (sctx->logit_bias_t == nullptr) { - return; - } - - // Add the sparse logit logit_bias to the logits - struct ggml_tensor * logit_biased = ggml_add_inplace(sctx->ctx, ggml_data->logits, sctx->logit_bias_t); - ggml_build_forward_expand(gf, logit_biased); -} - -static const char * llama_sampler_backend_logit_bias_name(const struct llama_sampler *) { - return "backend-logit_bias"; -} - -static void llama_sampler_backend_logit_bias_free(struct llama_sampler * smpl) { - auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx; - ggml_backend_buffer_free(sctx->buffer); - ggml_free(sctx->ctx); - delete sctx; -} - -static struct llama_sampler * llama_sampler_backend_logit_bias_clone(const struct llama_sampler * smpl) { - auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx; - return llama_sampler_backend_init_logit_bias(sctx->n_vocab, - sctx->logit_bias.size(), - sctx->logit_bias.data()); -} - - -struct llama_sampler * llama_sampler_backend_init_logit_bias(int32_t n_vocab, - int32_t n_logit_bias, - const llama_logit_bias * logit_bias) { - static const llama_sampler_i iface = { - /*.name =*/ llama_sampler_backend_logit_bias_name, - /*.accept =*/ nullptr, - /*.apply =*/ nullptr, - /*.reset =*/ nullptr, - /*.clone =*/ llama_sampler_backend_logit_bias_clone, - /*.free =*/ llama_sampler_backend_logit_bias_free, - /*.apply_ggml =*/ llama_sampler_backend_logit_bias_apply_ggml, - /*.accept_ggml =*/ nullptr, - /*.set_input_ggml =*/ llama_sampler_backend_logit_bias_set_input_ggml, - /*.init_ggml =*/ llama_sampler_backend_logit_bias_init_ggml, - }; - - auto * ctx_data = new llama_sampler_backend_logit_bias_ctx { - /*.n_vocab =*/ n_vocab, - /*.logit_bias =*/ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias), - /*.logit_bias_t =*/ nullptr, - /*.ctx =*/ nullptr, - /*.buffer =*/ nullptr, - }; - - auto * sampler = new llama_sampler { - /*.iface =*/ &iface, - /*.ctx =*/ ctx_data, - }; - - return sampler; -} - -struct llama_sampler_backend_min_p_ctx { - float p; - - // Only required for checking operation support and can be removed later.
- ggml_backend_dev_t device; -}; - -static void llama_sampler_backend_min_p_init_ggml( - struct llama_sampler * smpl, - ggml_backend_buffer_type_t buft) { - auto * sctx = (llama_sampler_backend_min_p_ctx *) smpl->ctx; - sctx->device = ggml_backend_buft_get_device(buft); -} - -static void llama_sampler_backend_min_p_apply_ggml( - struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data) { - GGML_UNUSED(gf); - - auto * sctx = (llama_sampler_backend_min_p_ctx *) smpl->ctx; - - struct ggml_tensor * max_idx = ggml_argmax(ctx, ggml_data->logits); - ggml_set_name(max_idx, "max_idx"); - - struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]); - ggml_set_name(logits_rows, "logits_rows"); - - struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx); - ggml_set_name(max_logit, "max_logit"); - - // Calculate the threshold value. - struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p)); - ggml_set_name(threshold, "min_p_threshold"); - - // Subtract the threshold from logits. - struct ggml_tensor * sub = ggml_sub(ctx, ggml_data->logits, threshold); - - // Create a mask where logits below the threshold are 0 (discard), - // and others are 1 (keep). - struct ggml_tensor * mask = ggml_step(ctx, sub); - ggml_set_name(mask, "min_p_mask"); - - // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes: - // min_p_bias = (mask * 1e9f) - 1e9f. - // So entries in the mask that we want to discard will become -1e9f, and - // others will be 0 (meaning that will not effect the logits). - const float large_val = 1e9f; - struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val); - ggml_set_name(min_p_bias, "min_p_bias"); - - // Add the min_p bias to the logits. - ggml_data->logits = ggml_add(ctx, ggml_data->logits, min_p_bias); - ggml_set_name(ggml_data->logits, "min_p_logits"); - - ggml_build_forward_expand(gf, ggml_data->logits); -} - -static const char * llama_sampler_backend_min_p_name(const struct llama_sampler *) { - return "backend-min-p"; -} - -static void llama_sampler_backend_min_p_free(struct llama_sampler * smpl) { - auto * sctx = (llama_sampler_backend_min_p_ctx *) smpl->ctx; - delete sctx; -} - -static struct llama_sampler * llama_sampler_backend_min_p_clone(const struct llama_sampler * smpl) { - auto * sctx = (llama_sampler_backend_min_p_ctx *) smpl->ctx; - return llama_sampler_backend_init_min_p(sctx->p); -} - -struct llama_sampler * llama_sampler_backend_init_min_p(float p) { - static const llama_sampler_i iface = { - /*.name =*/ llama_sampler_backend_min_p_name, - /*.accept =*/ nullptr, - /*.apply =*/ nullptr, - /*.reset =*/ nullptr, - /*.clone =*/ llama_sampler_backend_min_p_clone, - /*.free =*/ llama_sampler_backend_min_p_free, - /*.apply_ggml =*/ llama_sampler_backend_min_p_apply_ggml, - /*.accept_ggml =*/ nullptr, - /*.set_input_ggml =*/ nullptr, - /*.init_ggml =*/ llama_sampler_backend_min_p_init_ggml, - }; - - auto * sctx = new llama_sampler_backend_min_p_ctx { - /*.p =*/ p, - /*.device =*/ nullptr, - }; - - auto * sampler = new llama_sampler { - /*.iface =*/ &iface, - /*.ctx =*/ sctx, - }; - - return sampler; -} - - -struct llama_sampler_backend_top_p_ctx { - float p; - - // Only required for checking operation support and can be removed later. 
- ggml_backend_dev_t device; -}; - - -static void llama_sampler_backend_top_p_init_ggml( - struct llama_sampler * smpl, - ggml_backend_buffer_type_t buft) { - auto * sctx = (llama_sampler_backend_top_p_ctx *) smpl->ctx; - sctx->device = ggml_backend_buft_get_device(buft); -} - -static void llama_sampler_backend_top_p_apply_ggml( - struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data) { - GGML_UNUSED(gf); - - auto * sctx = (llama_sampler_backend_top_p_ctx *) smpl->ctx; - - struct ggml_tensor * softmax = ggml_soft_max(ctx, ggml_data->logits); - ggml_set_name(softmax, "top_p_softmax"); - - // Get the sorted indices of the softmax probabilities in descending order. - struct ggml_tensor * sorted_idx = ggml_argsort(ctx, softmax, GGML_SORT_ORDER_DESC); - ggml_set_name(sorted_idx, "top_p_sorted_idx"); - - // Do the sorting via reshape + get_rows - struct ggml_tensor * softmax_reshaped = ggml_reshape_2d(ctx, softmax, 1, softmax->ne[0]); - ggml_set_name(softmax_reshaped, "top_p_softmax_reshaped"); - - struct ggml_tensor * sorted_probs = ggml_get_rows(ctx, softmax_reshaped, sorted_idx); - ggml_set_name(sorted_probs, "top_p_sorted_probs"); - - struct ggml_tensor * sorted_probs_reshaped = ggml_reshape_2d(ctx, sorted_probs, softmax->ne[0], 1); - ggml_set_name(sorted_probs_reshaped, "top_p_sorted_probs_reshaped"); - // Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM. - struct ggml_tensor * sorted_cdf = ggml_cumsum(ctx, sorted_probs_reshaped); - ggml_set_name(sorted_cdf, "top_p_sorted_cdf"); - - // Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep - struct ggml_tensor * sorted_cdf_scaled = ggml_scale_bias(ctx, sorted_cdf, -1.0f, sctx->p); - ggml_set_name(sorted_cdf_scaled, "top_p_sorted_cdf_scaled"); - - struct ggml_tensor * sorted_mask = ggml_step(ctx, sorted_cdf_scaled); - ggml_set_name(sorted_mask, "top_p_sorted_mask"); - - // reverse sorting by argsort(argsort) - // cast to F32 since cuda only supports float inputs - struct ggml_tensor * reverse_argsort = ggml_argsort(ctx, ggml_cast(ctx, sorted_idx, GGML_TYPE_F32), GGML_SORT_ORDER_ASC); - ggml_set_name(reverse_argsort, "top_p_reverse_argsort"); - - // Do the sorting via reshape + get_rows - struct ggml_tensor * sorted_reshaped_mask = ggml_reshape_2d(ctx, sorted_mask, 1, sorted_mask->ne[0]); - ggml_set_name(sorted_reshaped_mask, "top_p_sorted_reshaped_mask"); - - struct ggml_tensor * reshaped_mask = ggml_get_rows(ctx, sorted_reshaped_mask, reverse_argsort); - ggml_set_name(reshaped_mask, "top_p_reshaped_mask"); - - struct ggml_tensor * mask = ggml_reshape_2d(ctx, reshaped_mask, sorted_mask->ne[0], 1); - ggml_set_name(mask, "top_p_mask"); - - // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes: - // top_p_bias = (mask * 1e9f) - 1e9f. - // So entries in the mask that we want to discard will become -1e9f, and - // others will be 0 (meaning that will not effect the logits). 
- const float large_val = 1e9f; - struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val); - ggml_set_name(top_p_bias, "top_p_bias"); - - ggml_data->logits = ggml_add(ctx, ggml_data->logits, top_p_bias); - ggml_set_name(ggml_data->logits, "top_p_logits"); - - ggml_build_forward_expand(gf, ggml_data->logits); -} - -static const char * llama_sampler_backend_top_p_name(const struct llama_sampler *) { - return "backend-top-p"; -} - -static void llama_sampler_backend_top_p_free(struct llama_sampler * smpl) { - auto * sctx = (llama_sampler_backend_top_p_ctx *) smpl->ctx; - delete sctx; -} - -static struct llama_sampler * llama_sampler_backend_top_p_clone(const struct llama_sampler * smpl) { - auto * sctx = (llama_sampler_backend_top_p_ctx *) smpl->ctx; - return llama_sampler_backend_init_top_p(sctx->p); -} - -struct llama_sampler * llama_sampler_backend_init_top_p(float p) { - static const llama_sampler_i iface = { - /*.name =*/ llama_sampler_backend_top_p_name, - /*.accept =*/ nullptr, - /*.apply =*/ nullptr, - /*.reset =*/ nullptr, - /*.clone =*/ llama_sampler_backend_top_p_clone, - /*.free =*/ llama_sampler_backend_top_p_free, - /*.apply_ggml =*/ llama_sampler_backend_top_p_apply_ggml, - /*.accept_ggml =*/ nullptr, - /*.set_input_ggml =*/ nullptr, - /*.init_ggml =*/ llama_sampler_backend_top_p_init_ggml, - }; - - auto * sctx = new llama_sampler_backend_top_p_ctx { - /*.p =*/ p, - /*.device =*/ nullptr, - }; - - auto * sampler = new llama_sampler { - /*.iface =*/ &iface, - /*.ctx =*/ sctx, - }; - - return sampler; -} diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 4d1760a629..a13be03240 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -4,6 +4,8 @@ #include "llama-vocab.h" #include "llama-grammar.h" +#include "ggml-cpp.h" + #include <algorithm> #include <cassert> #include <cfloat> @@ -610,16 +612,16 @@ static void llama_sampler_chain_set_backend_context( } static struct llama_sampler_i llama_sampler_chain_i = { - /* .name = */ llama_sampler_chain_name, - /* .accept = */ llama_sampler_chain_accept, - /* .apply = */ llama_sampler_chain_apply, - /* .reset = */ llama_sampler_chain_reset, - /* .clone = */ llama_sampler_chain_clone, - /* .free = */ llama_sampler_chain_free, - /* .apply_ggml = */ llama_sampler_chain_apply_ggml, - /* .accept_ggml = */ llama_sampler_chain_accept_ggml, - /* .set_input_ggml = */ llama_sampler_chain_set_input_ggml, - /* .set_backend_context = */ llama_sampler_chain_set_backend_context, + /* .name = */ llama_sampler_chain_name, + /* .accept = */ llama_sampler_chain_accept, + /* .apply = */ llama_sampler_chain_apply, + /* .reset = */ llama_sampler_chain_reset, + /* .clone = */ llama_sampler_chain_clone, + /* .free = */ llama_sampler_chain_free, + /* .apply_ggml = */ llama_sampler_chain_apply_ggml, + /* .accept_ggml = */ llama_sampler_chain_accept_ggml, + /* .set_input_ggml = */ llama_sampler_chain_set_input_ggml, + /* .init_ggml = */ llama_sampler_chain_set_backend_context, }; struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { @@ -687,17 +689,29 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to } } +static void llama_sampler_greedy_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + GGML_UNUSED(gf); + GGML_UNUSED(smpl); + struct ggml_tensor * argmax_result = ggml_argmax(ctx, ggml_data->logits); + ggml_set_name(argmax_result, "argmax_result"); + ggml_data->sampled =
argmax_result; +} + static struct llama_sampler_i llama_sampler_greedy_i = { - /* .name = */ llama_sampler_greedy_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_greedy_apply, - /* .reset = */ nullptr, - /* .clone = */ nullptr, - /* .free = */ nullptr, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .set_backend_context = */ nullptr, + /* .name = */ llama_sampler_greedy_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_greedy_apply, + /* .reset = */ nullptr, + /* .clone = */ nullptr, + /* .free = */ nullptr, + /* .apply_ggml = */ llama_sampler_greedy_apply_ggml, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .init_ggml = */ nullptr, }; struct llama_sampler * llama_sampler_init_greedy() { @@ -714,6 +728,14 @@ struct llama_sampler_dist { uint32_t seed_cur; std::mt19937 rng; + + // Only required for checking operation support and can be removed later. + ggml_backend_dev_t device; + + struct ggml_tensor * inp_uniform; + + ggml_context_ptr inp_ctx; + ggml_backend_buffer_ptr inp_buf; }; static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) { @@ -816,17 +838,109 @@ static void llama_sampler_dist_free(struct llama_sampler * smpl) { delete (llama_sampler_dist *) smpl->ctx; } +static void llama_sampler_dist_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + GGML_UNUSED(gf); + auto * sctx = (llama_sampler_dist *) smpl->ctx; + + struct ggml_tensor * probs = ggml_soft_max(ctx, ggml_data->logits); + ggml_set_name(probs, "dist_probs"); + + struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs); + if (sctx->device && !ggml_backend_dev_supports_op(sctx->device, cumsum)) { + fprintf(stderr, "Warning: backend does not support cumsum operation required for dist sampling\n"); + fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n"); + } + ggml_set_name(cumsum, "cumsum"); + + // The uniform tensor holds a random value and we subtract it from the + // cumsum tensor (the uniform tensor will be broadcast by ggml_sub). + // Recall that each entry in cumsum is the cumulative probability up to that + // index so values stay negative while the cumulative total is below the + // random value, and become zero/positive once the threshold is crossed. + struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform); + ggml_set_name(diff, "dist_cumsum"); + + // The ggml_step function produces a tensor where entries are 1 if the + // corresponding entry in diff is > 0, and 0 otherwise. So all values up to + // the index where the cumulative probability exceeds the random value are 0, + // and all entries after that are 1. + struct ggml_tensor * mask = ggml_step(ctx, diff); + ggml_set_name(mask, "dist_mask"); + + // Taking the sum of the mask gives us the number of elements past the + // threshold we are interested in. + struct ggml_tensor * idxf = ggml_sum(ctx, mask); + ggml_set_name(idxf, "dist_index_f32"); + + // Use ggml_scale_bias to scale the index value by -1 and then add the size + // of the mask to that value so we get the correct index ((-1 * idxf) + n). + struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32); + ggml_set_name(idx, "dist_index_i32"); + + // Map back to original vocab ids if a candidates tensor is available.
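+ // For example (hypothetical values): if top-k kept candidates [17, 3, 250, 9], + // an idx of 2 selects row 2 of that tensor, i.e. vocab id 250.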
+ struct ggml_tensor * sampled_token = idx; + if (ggml_data->candidates != nullptr) { + struct ggml_tensor * candidates = ggml_data->candidates; + struct ggml_tensor * candidates_reshaped = ggml_view_2d(ctx, candidates, 1, ggml_nelements(candidates), + ggml_type_size(candidates->type), 0); + + sampled_token = ggml_get_rows(ctx, candidates_reshaped, idx); + ggml_set_name(sampled_token, "dist_sampled_token"); + } + + ggml_set_output(sampled_token); + ggml_data->sampled = sampled_token; +} + +static void llama_sampler_dist_set_input_ggml(struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_dist *) smpl->ctx; + GGML_ASSERT(sctx->inp_uniform != nullptr); + + std::uniform_real_distribution<float> dist(0.0f, 1.0f); + const float rnd = dist(sctx->rng); + ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float)); +} + +static void llama_sampler_dist_init_ggml( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_dist *) smpl->ctx; + + sctx->device = ggml_backend_buft_get_device(buft); + + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + sctx->inp_ctx.reset(ggml_init(params)); + + // Create the uniform random scalar input tensor. This will be set by + // llama_sampler_dist_set_input_ggml after this graph is built. + sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1); + ggml_set_name(sctx->inp_uniform, "uniform"); + ggml_set_input(sctx->inp_uniform); + + // Allocate all tensors from our context to the backend + sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); +} + static struct llama_sampler_i llama_sampler_dist_i = { - /* .name = */ llama_sampler_dist_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_dist_apply, - /* .reset = */ llama_sampler_dist_reset, - /* .clone = */ llama_sampler_dist_clone, - /* .free = */ llama_sampler_dist_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .set_backend_context = */ nullptr, + /* .name = */ llama_sampler_dist_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_dist_apply, + /* .reset = */ llama_sampler_dist_reset, + /* .clone = */ llama_sampler_dist_clone, + /* .free = */ llama_sampler_dist_free, + /* .apply_ggml = */ llama_sampler_dist_apply_ggml, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ llama_sampler_dist_set_input_ggml, + /* .init_ggml = */ llama_sampler_dist_init_ggml, }; struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { @@ -834,9 +948,13 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { return llama_sampler_init( /* .iface = */ &llama_sampler_dist_i, /* .ctx = */ new llama_sampler_dist { - /* .seed = */ seed, - /* .seed_cur = */ seed_cur, - /* .rng = */ std::mt19937(seed_cur), + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), + /* .device = */ nullptr, + /* .inp_uniform = */ nullptr, + /* .inp_ctx = */ nullptr, + /* .inp_buf = */ nullptr, } ); } @@ -845,6 +963,9 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { struct llama_sampler_top_k { const int32_t k; + + // Only required for checking operation support and can be removed later.
+ ggml_backend_dev_t device; }; static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) { @@ -865,24 +986,60 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) { delete (llama_sampler_top_k *) smpl->ctx; } +static void llama_sampler_top_k_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + + auto * ctx_data = (llama_sampler_top_k *) smpl->ctx; + + struct ggml_tensor * top_k = ggml_top_k(ctx, ggml_data->logits, ctx_data->k); + ggml_set_name(top_k, "top_k"); + + // top_k is a view of argsort - check if backend supports the underlying argsort operation + // by checking the source tensor (which is the argsort result) + if (ctx_data->device && top_k->src[0] && !ggml_backend_dev_supports_op(ctx_data->device, top_k->src[0])) { + fprintf(stderr, "Warning: backend does not support argsort operation required for top-k sampling\n"); + fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n"); + } + + ggml_data->candidates = top_k; + + struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]); + struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k); + ggml_set_name(top_k_rows, "top_k_rows"); + + ggml_data->logits = ggml_reshape_1d(ctx, top_k_rows, ctx_data->k); + ggml_build_forward_expand(gf, ggml_data->logits); +} + +static void llama_sampler_top_k_init_ggml( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * ctx_data = (llama_sampler_top_k *) smpl->ctx; + ctx_data->device = ggml_backend_buft_get_device(buft); +} + static struct llama_sampler_i llama_sampler_top_k_i = { - /* .name = */ llama_sampler_top_k_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_k_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_k_clone, - /* .free = */ llama_sampler_top_k_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .set_backend_context = */ nullptr, + /* .name = */ llama_sampler_top_k_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_k_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_k_clone, + /* .free = */ llama_sampler_top_k_free, + /* .apply_ggml = */ llama_sampler_top_k_apply_ggml, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .init_ggml = */ llama_sampler_top_k_init_ggml, }; struct llama_sampler * llama_sampler_init_top_k(int32_t k) { return llama_sampler_init( /* .iface = */ &llama_sampler_top_k_i, /* .ctx = */ new llama_sampler_top_k { - /* .k = */ k, + /* .k = */ k, + /* .device = */ nullptr, } ); } @@ -894,6 +1051,9 @@ struct llama_sampler_top_p { const float p; const size_t min_keep; std::vector<llama_token_data> buf_sort; + + // Only required for checking operation support and can be removed later.
+ ggml_backend_dev_t device; }; static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) { @@ -964,17 +1124,87 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) { delete (llama_sampler_top_p *) smpl->ctx; } +static void llama_sampler_top_p_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + auto * sctx = (llama_sampler_top_p *) smpl->ctx; + + struct ggml_tensor * softmax = ggml_soft_max(ctx, ggml_data->logits); + ggml_set_name(softmax, "top_p_softmax"); + + // Get the sorted indices of the softmax probabilities in descending order. + struct ggml_tensor * sorted_idx = ggml_argsort(ctx, softmax, GGML_SORT_ORDER_DESC); + ggml_set_name(sorted_idx, "top_p_sorted_idx"); + + // Do the sorting via reshape + get_rows + struct ggml_tensor * softmax_reshaped = ggml_reshape_2d(ctx, softmax, 1, softmax->ne[0]); + ggml_set_name(softmax_reshaped, "top_p_softmax_reshaped"); + + struct ggml_tensor * sorted_probs = ggml_get_rows(ctx, softmax_reshaped, sorted_idx); + ggml_set_name(sorted_probs, "top_p_sorted_probs"); + + struct ggml_tensor * sorted_probs_reshaped = ggml_reshape_2d(ctx, sorted_probs, softmax->ne[0], 1); + ggml_set_name(sorted_probs_reshaped, "top_p_sorted_probs_reshaped"); + // Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM. + struct ggml_tensor * sorted_cdf = ggml_cumsum(ctx, sorted_probs_reshaped); + ggml_set_name(sorted_cdf, "top_p_sorted_cdf"); + + // Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep + struct ggml_tensor * sorted_cdf_scaled = ggml_scale_bias(ctx, sorted_cdf, -1.0f, sctx->p); + ggml_set_name(sorted_cdf_scaled, "top_p_sorted_cdf_scaled"); + + struct ggml_tensor * sorted_mask = ggml_step(ctx, sorted_cdf_scaled); + ggml_set_name(sorted_mask, "top_p_sorted_mask"); + + // reverse sorting by argsort(argsort) + // cast to F32 since CUDA only supports float inputs + struct ggml_tensor * reverse_argsort = ggml_argsort(ctx, ggml_cast(ctx, sorted_idx, GGML_TYPE_F32), GGML_SORT_ORDER_ASC); + ggml_set_name(reverse_argsort, "top_p_reverse_argsort"); + + // Do the sorting via reshape + get_rows + struct ggml_tensor * sorted_reshaped_mask = ggml_reshape_2d(ctx, sorted_mask, 1, sorted_mask->ne[0]); + ggml_set_name(sorted_reshaped_mask, "top_p_sorted_reshaped_mask"); + + struct ggml_tensor * reshaped_mask = ggml_get_rows(ctx, sorted_reshaped_mask, reverse_argsort); + ggml_set_name(reshaped_mask, "top_p_reshaped_mask"); + + struct ggml_tensor * mask = ggml_reshape_2d(ctx, reshaped_mask, sorted_mask->ne[0], 1); + ggml_set_name(mask, "top_p_mask"); + + // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes: + // top_p_bias = (mask * 1e9f) - 1e9f. + // So entries in the mask that we want to discard will become -1e9f, and + // others will be 0 (meaning they will not affect the logits).
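+ // For example (hypothetical values): a keep-mask of [1, 1, 0, 0] becomes a bias of + // [0, 0, -1e9, -1e9], so the two kept logits pass through unchanged.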
+ const float large_val = 1e9f; + struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val); + ggml_set_name(top_p_bias, "top_p_bias"); + + ggml_data->logits = ggml_add(ctx, ggml_data->logits, top_p_bias); + ggml_set_name(ggml_data->logits, "top_p_logits"); + + ggml_build_forward_expand(gf, ggml_data->logits); +} + +static void llama_sampler_top_p_init_ggml( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_top_p *) smpl->ctx; + sctx->device = ggml_backend_buft_get_device(buft); +} + static struct llama_sampler_i llama_sampler_top_p_i = { - /* .name = */ llama_sampler_top_p_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_p_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_p_clone, - /* .free = */ llama_sampler_top_p_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .set_backend_context = */ nullptr, + /* .name = */ llama_sampler_top_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_p_clone, + /* .free = */ llama_sampler_top_p_free, + /* .apply_ggml = */ llama_sampler_top_p_apply_ggml, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .init_ggml = */ llama_sampler_top_p_init_ggml, }; struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { @@ -984,6 +1214,7 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { /* .p = */ p, /* .min_keep = */ min_keep, /* .buf_sort = */ {}, + /* .device = */ nullptr, } ); } @@ -993,6 +1224,9 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { struct llama_sampler_min_p { const float p; const size_t min_keep; + + // Only required for checking operation support and can be removed later. + ggml_backend_dev_t device; }; static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) { @@ -1062,17 +1296,67 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) { delete (llama_sampler_min_p *) smpl->ctx; } +static void llama_sampler_min_p_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + auto * sctx = (llama_sampler_min_p *) smpl->ctx; + + struct ggml_tensor * max_idx = ggml_argmax(ctx, ggml_data->logits); + ggml_set_name(max_idx, "max_idx"); + + struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]); + ggml_set_name(logits_rows, "logits_rows"); + + struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx); + ggml_set_name(max_logit, "max_logit"); + + // Calculate the threshold value. + struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p)); + ggml_set_name(threshold, "min_p_threshold"); + + // Subtract the threshold from logits. + struct ggml_tensor * sub = ggml_sub(ctx, ggml_data->logits, threshold); + + // Create a mask where logits below the threshold are 0 (discard), + // and others are 1 (keep). + struct ggml_tensor * mask = ggml_step(ctx, sub); + ggml_set_name(mask, "min_p_mask"); + + // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes: + // min_p_bias = (mask * 1e9f) - 1e9f. + // So entries in the mask that we want to discard will become -1e9f, and + // others will be 0 (meaning they will not affect the logits).
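+ // For example (hypothetical values): with p = 0.1 the threshold sits log(10) ~ 2.3 + // below the max logit, so any token more than 2.3 below the max gets the -1e9 bias.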
+ const float large_val = 1e9f; + struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val); + ggml_set_name(min_p_bias, "min_p_bias"); + + // Add the min_p bias to the logits. + ggml_data->logits = ggml_add(ctx, ggml_data->logits, min_p_bias); + ggml_set_name(ggml_data->logits, "min_p_logits"); + + ggml_build_forward_expand(gf, ggml_data->logits); +} + +static void llama_sampler_min_p_init_ggml( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_min_p *) smpl->ctx; + sctx->device = ggml_backend_buft_get_device(buft); +} + static struct llama_sampler_i llama_sampler_min_p_i = { - /* .name = */ llama_sampler_min_p_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_min_p_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_min_p_clone, - /* .free = */ llama_sampler_min_p_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .set_backend_context = */ nullptr, + /* .name = */ llama_sampler_min_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_min_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_min_p_clone, + /* .free = */ llama_sampler_min_p_free, + /* .apply_ggml = */ llama_sampler_min_p_apply_ggml, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .init_ggml = */ llama_sampler_min_p_init_ggml, }; struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { @@ -1081,6 +1365,7 @@ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { /* .ctx = */ new llama_sampler_min_p { /* .p = */ p, /* .min_keep = */ min_keep, + /* .device = */ nullptr, } ); } @@ -1166,16 +1451,16 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_typical_i = { - /* .name = */ llama_sampler_typical_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_typical_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_typical_clone, - /* .free = */ llama_sampler_typical_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .set_backend_context = */ nullptr, + /* .name = */ llama_sampler_typical_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_typical_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_typical_clone, + /* .free = */ llama_sampler_typical_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .init_ggml = */ nullptr, }; struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { @@ -1213,17 +1498,38 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) { delete (llama_sampler_temp *) smpl->ctx; } +static void llama_sampler_temp_apply_ggml( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_ggml_data * ggml_data) { + auto * ctx_data = (llama_sampler_temp *) smpl->ctx; + + if (ctx_data->temp <= 0.0f) { + return; + } + + struct ggml_tensor * scaled = ggml_scale(ctx, ggml_data->logits, 1.0f / ctx_data->temp); + ggml_set_name(scaled, "temp_scaled"); + + // Make sure the scaled tensor is contiguous for subsequent operations + ggml_data->logits = ggml_cont(ctx, scaled); + ggml_set_name(ggml_data->logits, "temp_scaled_logits"); + + ggml_build_forward_expand(gf, ggml_data->logits); +} + static struct llama_sampler_i llama_sampler_temp_i = { - /* .name = */ 
llama_sampler_temp_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_temp_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_temp_clone, - /* .free = */ llama_sampler_temp_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .set_backend_context = */ nullptr, + /* .name = */ llama_sampler_temp_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_clone, + /* .free = */ llama_sampler_temp_free, + /* .apply_ggml = */ llama_sampler_temp_apply_ggml, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .init_ggml = */ nullptr, }; struct llama_sampler * llama_sampler_init_temp(float temp) { @@ -1328,16 +1634,16 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_temp_ext_i = { - /* .name = */ llama_sampler_temp_ext_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_temp_ext_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_temp_ext_clone, - /* .free = */ llama_sampler_temp_ext_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .set_backend_context = */ nullptr, + /* .name = */ llama_sampler_temp_ext_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_ext_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_ext_clone, + /* .free = */ llama_sampler_temp_ext_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .init_ggml = */ nullptr, }; struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { @@ -1426,16 +1732,16 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_xtc_i = { - /* .name = */ llama_sampler_xtc_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sample_xtc_apply, - /* .reset = */ llama_sampler_xtc_reset, - /* .clone = */ llama_sampler_xtc_clone, - /* .free = */ llama_sampler_xtc_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .set_backend_context = */ nullptr, + /* .name = */ llama_sampler_xtc_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sample_xtc_apply, + /* .reset = */ llama_sampler_xtc_reset, + /* .clone = */ llama_sampler_xtc_clone, + /* .free = */ llama_sampler_xtc_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .init_ggml = */ nullptr, }; struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { @@ -1538,16 +1844,16 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_mirostat_i = { - /* .name = */ llama_sampler_mirostat_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_mirostat_apply, - /* .reset = */ llama_sampler_mirostat_reset, - /* .clone = */ llama_sampler_mirostat_clone, - /* .free = */ llama_sampler_mirostat_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .set_backend_context = */ nullptr, + /* .name = */ llama_sampler_mirostat_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_apply, + /* .reset = */ llama_sampler_mirostat_reset, + /* .clone = */ llama_sampler_mirostat_clone, + /* .free = */ 
llama_sampler_mirostat_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .init_ggml = */ nullptr, }; struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { @@ -1641,16 +1947,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_mirostat_v2_i = { - /* .name = */ llama_sampler_mirostat_v2_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_mirostat_v2_apply, - /* .reset = */ llama_sampler_mirostat_v2_reset, - /* .clone = */ llama_sampler_mirostat_v2_clone, - /* .free = */ llama_sampler_mirostat_v2_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .set_backend_context = */ nullptr, + /* .name = */ llama_sampler_mirostat_v2_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_v2_apply, + /* .reset = */ llama_sampler_mirostat_v2_reset, + /* .clone = */ llama_sampler_mirostat_v2_clone, + /* .free = */ llama_sampler_mirostat_v2_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .init_ggml = */ nullptr, }; struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { @@ -1762,16 +2068,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_grammar_i = { - /* .name = */ llama_sampler_grammar_name, - /* .accept = */ llama_sampler_grammar_accept_impl, - /* .apply = */ llama_sampler_grammar_apply, - /* .reset = */ llama_sampler_grammar_reset, - /* .clone = */ llama_sampler_grammar_clone, - /* .free = */ llama_sampler_grammar_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .set_backend_context = */ nullptr, + /* .name = */ llama_sampler_grammar_name, + /* .accept = */ llama_sampler_grammar_accept_impl, + /* .apply = */ llama_sampler_grammar_apply, + /* .reset = */ llama_sampler_grammar_reset, + /* .clone = */ llama_sampler_grammar_clone, + /* .free = */ llama_sampler_grammar_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .init_ggml = */ nullptr, }; static struct llama_sampler * llama_sampler_init_grammar_impl( @@ -1973,16 +2279,16 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_penalties_i = { - /* .name = */ llama_sampler_penalties_name, - /* .accept = */ llama_sampler_penalties_accept, - /* .apply = */ llama_sampler_penalties_apply, - /* .reset = */ llama_sampler_penalties_reset, - /* .clone = */ llama_sampler_penalties_clone, - /* .free = */ llama_sampler_penalties_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .set_backend_context = */ nullptr, + /* .name = */ llama_sampler_penalties_name, + /* .accept = */ llama_sampler_penalties_accept, + /* .apply = */ llama_sampler_penalties_apply, + /* .reset = */ llama_sampler_penalties_reset, + /* .clone = */ llama_sampler_penalties_clone, + /* .free = */ llama_sampler_penalties_free, + /* .apply_ggml = */ nullptr, + /* .accept_ggml = */ nullptr, + /* .set_input_ggml = */ nullptr, + /* .init_ggml = */ nullptr, }; struct llama_sampler * llama_sampler_init_penalties( @@ -2068,16 +2374,16 @@ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) { } static 
@@ -2068,16 +2374,16 @@ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
 }
 
 static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
-    /* .name                = */ llama_sampler_top_n_sigma_name,
-    /* .accept              = */ nullptr,
-    /* .apply               = */ llama_sampler_top_n_sigma_apply,
-    /* .reset               = */ nullptr,
-    /* .clone               = */ llama_sampler_top_n_sigma_clone,
-    /* .free                = */ llama_sampler_top_n_sigma_free,
-    /* .apply_ggml          = */ nullptr,
-    /* .accept_ggml         = */ nullptr,
-    /* .set_input_ggml      = */ nullptr,
-    /* .set_backend_context = */ nullptr,
+    /* .name           = */ llama_sampler_top_n_sigma_name,
+    /* .accept         = */ nullptr,
+    /* .apply          = */ llama_sampler_top_n_sigma_apply,
+    /* .reset          = */ nullptr,
+    /* .clone          = */ llama_sampler_top_n_sigma_clone,
+    /* .free           = */ llama_sampler_top_n_sigma_free,
+    /* .apply_ggml     = */ nullptr,
+    /* .accept_ggml    = */ nullptr,
+    /* .set_input_ggml = */ nullptr,
+    /* .init_ggml      = */ nullptr,
 };
 
 struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
@@ -2402,16 +2708,16 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) {
 }
 
 static struct llama_sampler_i llama_sampler_dry_i = {
-    /* .name                = */ llama_sampler_dry_name,
-    /* .accept              = */ llama_sampler_dry_accept,
-    /* .apply               = */ llama_sampler_dry_apply,
-    /* .reset               = */ llama_sampler_dry_reset,
-    /* .clone               = */ llama_sampler_dry_clone,
-    /* .free                = */ llama_sampler_dry_free,
-    /* .apply_ggml          = */ nullptr,
-    /* .accept_ggml         = */ nullptr,
-    /* .set_input_ggml      = */ nullptr,
-    /* .set_backend_context = */ nullptr,
+    /* .name           = */ llama_sampler_dry_name,
+    /* .accept         = */ llama_sampler_dry_accept,
+    /* .apply          = */ llama_sampler_dry_apply,
+    /* .reset          = */ llama_sampler_dry_reset,
+    /* .clone          = */ llama_sampler_dry_clone,
+    /* .free           = */ llama_sampler_dry_free,
+    /* .apply_ggml     = */ nullptr,
+    /* .accept_ggml    = */ nullptr,
+    /* .set_input_ggml = */ nullptr,
+    /* .init_ggml      = */ nullptr,
 };
 
 struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
@@ -2498,6 +2804,11 @@ struct llama_sampler_logit_bias {
     const std::vector<llama_logit_bias> logit_bias;
 
     std::vector<llama_logit_bias> to_search;
+
+    struct ggml_tensor * inp_logit_bias;
+
+    ggml_context_ptr inp_ctx;
+    ggml_backend_buffer_ptr inp_buf;
 };
 
 static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
@@ -2546,17 +2857,77 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
     delete (llama_sampler_logit_bias *) smpl->ctx;
 }
 
+static void llama_sampler_logit_bias_apply_ggml(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_ggml_data * ggml_data) {
+    GGML_UNUSED(gf);
+    GGML_UNUSED(ctx);
+
+    auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
+    if (sctx->logit_bias.empty()) {
+        return;
+    }
+
+    // Add the sparse logit bias to the logits
+    struct ggml_tensor * logit_biased = ggml_add_inplace(ctx, ggml_data->logits, sctx->inp_logit_bias);
+    ggml_build_forward_expand(gf, logit_biased);
+}
+
+static void llama_sampler_logit_bias_set_input_ggml(struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
+    if (sctx->logit_bias.empty()) {
+        return;
+    }
+    GGML_ASSERT(sctx->inp_logit_bias != nullptr);
+
+    // Create a sparse logit_bias vector from the logit_bias entries.
+    std::vector<float> logit_bias_sparse(sctx->n_vocab, 0.0f);
+    for (const auto & lb : sctx->logit_bias) {
+        GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
+        logit_bias_sparse[lb.token] = lb.bias;
+    }
+
+    ggml_backend_tensor_set(sctx->inp_logit_bias, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
+}
+
+static void llama_sampler_logit_bias_init_ggml(
+        struct llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
+
+    if (sctx->logit_bias.empty()) {
+        return;
+    }
+
+    ggml_init_params params = {
+        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc   =*/ true,
+    };
+
+    sctx->inp_ctx.reset(ggml_init(params));
+
+    sctx->inp_logit_bias = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, sctx->n_vocab);
+    ggml_set_name(sctx->inp_logit_bias, "logit_bias");
+    ggml_set_input(sctx->inp_logit_bias);
+
+    // Allocate all tensors from our context to the backend
+    sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
+}
+
 static struct llama_sampler_i llama_sampler_logit_bias_i = {
-    /* .name                = */ llama_sampler_logit_bias_name,
-    /* .accept              = */ nullptr,
-    /* .apply               = */ llama_sampler_logit_bias_apply,
-    /* .reset               = */ nullptr,
-    /* .clone               = */ llama_sampler_logit_bias_clone,
-    /* .free                = */ llama_sampler_logit_bias_free,
-    /* .apply_ggml          = */ nullptr,
-    /* .accept_ggml         = */ nullptr,
-    /* .set_input_ggml      = */ nullptr,
-    /* .set_backend_context = */ nullptr,
+    /* .name           = */ llama_sampler_logit_bias_name,
+    /* .accept         = */ nullptr,
+    /* .apply          = */ llama_sampler_logit_bias_apply,
+    /* .reset          = */ nullptr,
+    /* .clone          = */ llama_sampler_logit_bias_clone,
+    /* .free           = */ llama_sampler_logit_bias_free,
+    /* .apply_ggml     = */ llama_sampler_logit_bias_apply_ggml,
+    /* .accept_ggml    = */ nullptr,
+    /* .set_input_ggml = */ llama_sampler_logit_bias_set_input_ggml,
+    /* .init_ggml      = */ llama_sampler_logit_bias_init_ggml,
 };
 
 struct llama_sampler * llama_sampler_init_logit_bias(
@@ -2566,9 +2937,12 @@ struct llama_sampler * llama_sampler_init_logit_bias(
     return llama_sampler_init(
         /* .iface = */ &llama_sampler_logit_bias_i,
         /* .ctx   = */ new llama_sampler_logit_bias {
-            /* .n_vocab    = */ n_vocab,
-            /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
-            /* .to_search  = */ {},
+            /* .n_vocab        = */ n_vocab,
+            /* .logit_bias     = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
+            /* .to_search      = */ {},
+            /* .inp_logit_bias = */ nullptr,
+            /* .inp_ctx        = */ nullptr,
+            /* .inp_buf        = */ nullptr,
         }
     );
 }
@@ -2781,16 +3155,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) {
 }
 
 static struct llama_sampler_i llama_sampler_infill_i = {
-    /* .name                = */ llama_sampler_infill_name,
-    /* .accept              = */ nullptr,
-    /* .apply               = */ llama_sampler_infill_apply,
-    /* .reset               = */ nullptr,
-    /* .clone               = */ llama_sampler_infill_clone,
-    /* .free                = */ llama_sampler_infill_free,
-    /* .apply_ggml          = */ nullptr,
-    /* .accept_ggml         = */ nullptr,
-    /* .set_input_ggml      = */ nullptr,
-    /* .set_backend_context = */ nullptr,
+    /* .name           = */ llama_sampler_infill_name,
+    /* .accept         = */ nullptr,
+    /* .apply          = */ llama_sampler_infill_apply,
+    /* .reset          = */ nullptr,
+    /* .clone          = */ llama_sampler_infill_clone,
+    /* .free           = */ llama_sampler_infill_free,
+    /* .apply_ggml     = */ nullptr,
+    /* .accept_ggml    = */ nullptr,
+    /* .set_input_ggml = */ nullptr,
+    /* .init_ggml      = */ nullptr,
 };
 
 struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
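The three logit-bias hooks above split the work by lifecycle stage: init_ggml allocates a dense F32 tensor of n_vocab elements once, in the sampler's own context and buffer so it outlives any single graph; apply_ggml records one ggml_add_inplace node; and set_input_ggml scatters the sparse {token, bias} pairs into the dense vector before compute, since ggml_add needs a full-width operand. A sketch of the order in which a driver invokes them (the actual driver lives in the llama context/graph code, outside this diff; smpl, buft, ctx, gf and ggml_data are assumed to exist at the call sites):

// once, when the sampler chain is attached to a backend:
smpl->iface->init_ggml(smpl, buft);                 // allocates inp_logit_bias

// each time the decode graph is (re)built:
smpl->iface->apply_ggml(smpl, ctx, gf, &ggml_data); // logits += inp_logit_bias

// each evaluation, after graph allocation and before the graph is computed:
smpl->iface->set_input_ggml(smpl);                  // uploads the dense bias vector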
diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp
index 47d2a139ea..d6839c8805 100644
--- a/tests/test-backend-sampler.cpp
+++ b/tests/test-backend-sampler.cpp
@@ -277,7 +277,7 @@ static void test_backend_greedy_sampling(const char * model_path) {
     struct llama_sampler_chain_params backend_sampler_params = llama_sampler_chain_default_params();
     struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_sampler_params);
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_greedy());
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_greedy());
 
     std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -315,7 +315,7 @@ static void test_backend_top_k_sampling(const char * model_path) {
     const int32_t k = 8;
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
     struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_top_k(k));
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_top_k(k));
 
     std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -363,12 +363,12 @@ static void test_backend_temp_sampling(const char * model_path) {
     const float temp_0 = 0.8f;
     struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
     struct llama_sampler * backend_sampler_chain_0 = llama_sampler_chain_init(backend_chain_params_0);
-    llama_sampler_chain_add(backend_sampler_chain_0, llama_sampler_backend_init_temp(temp_0));
+    llama_sampler_chain_add(backend_sampler_chain_0, llama_sampler_init_temp(temp_0));
 
     const float temp_1 = 0.1f;
     struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params();
     struct llama_sampler * backend_sampler_chain_1 = llama_sampler_chain_init(backend_chain_params_1);
-    llama_sampler_chain_add(backend_sampler_chain_1, llama_sampler_backend_init_temp(temp_1));
+    llama_sampler_chain_add(backend_sampler_chain_1, llama_sampler_init_temp(temp_1));
 
     std::vector backend_sampler_configs = {
         { 0, backend_sampler_chain_0 },
@@ -430,7 +430,7 @@ static void test_backend_min_p_sampling(const char * model_path) {
     const float p = 0.1;
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
     struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_min_p(p));
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_min_p(p, 0));
 
     std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -488,7 +488,7 @@ static void test_backend_top_p_sampling(const char * model_path) {
     const float p = 0.9;
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
     struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_top_p(p));
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_top_p(p, 0));
 
     std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -541,12 +541,12 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
     struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
     struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
-    llama_sampler_chain_add(sampler_chain_0, llama_sampler_backend_init_greedy());
+    llama_sampler_chain_add(sampler_chain_0, llama_sampler_init_greedy());
 
     struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
     struct llama_sampler * sampler_chain_1 = llama_sampler_chain_init(chain_params_1);
-    llama_sampler_chain_add(sampler_chain_1, llama_sampler_backend_init_temp(0.8f));
-    llama_sampler_chain_add(sampler_chain_1, llama_sampler_backend_init_greedy());
+    llama_sampler_chain_add(sampler_chain_1, llama_sampler_init_temp(0.8f));
+    llama_sampler_chain_add(sampler_chain_1, llama_sampler_init_greedy());
 
     std::vector backend_sampler_configs = {
         { 0, sampler_chain_0 },
@@ -613,7 +613,7 @@ static void test_backend_dist_sampling(const char * model_path) {
     const int32_t seed = 88;
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
     struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed));
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));
 
     std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -642,7 +642,7 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
     const int32_t seed = 88;
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
     struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed));
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));
 
     std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -689,11 +689,11 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
     struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_logit_bias(
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_logit_bias(
                 llama_vocab_n_tokens(test_ctx.vocab),
                 logit_bias.size(),
                 logit_bias.data()));
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(88));
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(88));
 
     std::vector backend_sampler_configs = {
         { seq_id, backend_sampler_chain },
@@ -720,12 +720,12 @@ static void test_backend_mixed_sampling(const char * model_path) {
     struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
     struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
-    llama_sampler_chain_add(sampler_chain_0, llama_sampler_backend_init_dist(88));
+    llama_sampler_chain_add(sampler_chain_0, llama_sampler_init_dist(88));
 
     int k = 40;
     struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
     struct llama_sampler * sampler_chain_1 = llama_sampler_chain_init(chain_params_1);
-    llama_sampler_chain_add(sampler_chain_1, llama_sampler_backend_init_top_k(k));
+    llama_sampler_chain_add(sampler_chain_1, llama_sampler_init_top_k(k));
 
     std::vector backend_sampler_configs = {
         { 0, sampler_chain_0 },
@@ -776,7 +776,7 @@ static void test_backend_set_sampler(const char * model_path) {
     const int seq_id = 0;
     struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
     struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed));
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));
 
     std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@@ -822,8 +822,8 @@ static void test_backend_set_sampler(const char * model_path) {
     // Set a new backend sampler for the sequence.
     struct llama_sampler_chain_params new_backend_chain_params = llama_sampler_chain_default_params();
     struct llama_sampler * new_backend_sampler_chain = llama_sampler_chain_init(new_backend_chain_params);
-    llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_backend_init_top_k(20));
-    llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_backend_init_dist(seed));
+    llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_init_top_k(20));
+    llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_init_dist(seed));
     llama_set_backend_sampler(test_ctx.ctx, seq_id, new_backend_sampler_chain);
 
     if (!test_ctx.decode_tokens(tokens2)) {
@@ -841,7 +841,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
     // Sequence 0 uses backend sampling
     struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
     struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
-    llama_sampler_chain_add(sampler_chain_0, llama_sampler_backend_init_dist(88));
+    llama_sampler_chain_add(sampler_chain_0, llama_sampler_init_dist(88));
 
     std::vector backend_sampler_configs = {
         { 0, sampler_chain_0 },
@@ -912,7 +912,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
     {
         struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
         struct llama_sampler * sampler_chain= llama_sampler_chain_init(chain_params);
-        llama_sampler_chain_add(sampler_chain, llama_sampler_backend_init_dist(88));
+        llama_sampler_chain_add(sampler_chain, llama_sampler_init_dist(88));
 
         llama_set_backend_sampler(test_ctx.ctx, 0, sampler_chain);
 
@@ -937,7 +937,7 @@ static void test_backend_max_outputs(const char * model_path) {
     const int32_t seed = 88;
     llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
     llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
-    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed));
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));
 
     std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
     if (!test_ctx.setup(model_path, backend_sampler_configs)) {
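The test changes above are mechanical, but they show the point of the patch: the separate llama_sampler_backend_init_* constructors are gone, and the single llama_sampler_init_* family now serves both the CPU and the backend path. A sketch of the resulting caller-side symmetry, assuming a loaded ctx (llama_set_backend_sampler and the seq_id convention are taken from the tests above):

llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
llama_sampler_chain_add(chain, llama_sampler_init_dist(88));

// backend path: attach the chain to sequence 0 so sampling runs in the decode graph
llama_set_backend_sampler(ctx, 0, chain);

// CPU path: the very same constructors feed the classic sampling loop instead,
// e.g. llama_token tok = llama_sampler_sample(chain, ctx, -1);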