diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index d96f619ae1..f3891453e4 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -421,39 +421,6 @@ void llama_sampler_free(struct llama_sampler * smpl) { delete smpl; } -llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { - const auto * logits = llama_get_logits_ith(ctx, idx); - - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); - - const int n_vocab = llama_vocab_n_tokens(vocab); - - // TODO: do not allocate each time - std::vector cur; - cur.reserve(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); - } - - llama_token_data_array cur_p = { - /* .data = */ cur.data(), - /* .size = */ cur.size(), - /* .selected = */ -1, - /* .sorted = */ false, - }; - - llama_sampler_apply(smpl, &cur_p); - - GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size); - - auto token = cur_p.data[cur_p.selected].id; - - llama_sampler_accept(smpl, token); - - return token; -} - // sampler chain static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) { @@ -527,12 +494,56 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param /* .ctx = */ new llama_sampler_chain { /* .params = */ params, /* .samplers = */ {}, + /* .cur = */ {}, /* .t_sample_us = */ 0, /* .n_sample = */ 0, } ); } +llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { + const auto * logits = llama_get_logits_ith(ctx, idx); + + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_vocab = llama_vocab_n_tokens(vocab); + + // use pre-allocated buffer from chain if available, otherwise allocate locally + std::vector * cur_ptr; + std::vector cur_local; + + if (smpl->iface == &llama_sampler_chain_i) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + cur_ptr = &chain->cur; + } else { + cur_ptr = &cur_local; + } + + auto & cur = *cur_ptr; + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + + llama_token_data_array cur_p = { + /* .data = */ cur.data(), + /* .size = */ cur.size(), + /* .selected = */ -1, + /* .sorted = */ false, + }; + + llama_sampler_apply(smpl, &cur_p); + + GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size); + + auto token = cur_p.data[cur_p.selected].id; + + llama_sampler_accept(smpl, token); + + return token; +} + void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { auto * p = (llama_sampler_chain *) chain->ctx; p->samplers.push_back(smpl); diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 759dd7dcb7..1e3de4e2ec 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -16,6 +16,9 @@ struct llama_sampler_chain { std::vector samplers; + // pre-allocated buffer for llama_sampler_sample to avoid repeated allocations + std::vector cur; + // timing mutable int64_t t_sample_us;