sampling: reuse token data buffer in llama_sampler_sample (#18365)
* sampling: reuse token data buffer in llama_sampler_sample
* move cur buffer before timing section, after samplers
* minor : fix build
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent f14f4e421b
commit c32fa21db8
@@ -421,39 +421,6 @@ void llama_sampler_free(struct llama_sampler * smpl) {
     delete smpl;
 }
 
-llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
-    const auto * logits = llama_get_logits_ith(ctx, idx);
-
-    const llama_model * model = llama_get_model(ctx);
-    const llama_vocab * vocab = llama_model_get_vocab(model);
-
-    const int n_vocab = llama_vocab_n_tokens(vocab);
-
-    // TODO: do not allocate each time
-    std::vector<llama_token_data> cur;
-    cur.reserve(n_vocab);
-    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-    }
-
-    llama_token_data_array cur_p = {
-        /* .data     = */ cur.data(),
-        /* .size     = */ cur.size(),
-        /* .selected = */ -1,
-        /* .sorted   = */ false,
-    };
-
-    llama_sampler_apply(smpl, &cur_p);
-
-    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
-
-    auto token = cur_p.data[cur_p.selected].id;
-
-    llama_sampler_accept(smpl, token);
-
-    return token;
-}
-
 // sampler chain
 
 static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
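The removed version above builds a fresh std::vector<llama_token_data> on every sampling call; its own "// TODO: do not allocate each time" comment flags exactly this cost. The replacement below instead keeps the buffer alive inside the sampler chain. A minimal standalone sketch of that reuse pattern, with hypothetical names (scratch, fill), to show why it helps:

#include <cstdio>
#include <vector>

// Hypothetical illustration of the buffer-reuse pattern this commit applies:
// the scratch vector lives in a long-lived object and is resize()d in place,
// so once capacity has grown to n, later calls perform no heap allocation
// (n_vocab is constant across sampling calls in the real code).
struct scratch {
    std::vector<int> buf; // persists across calls; capacity never shrinks
};

static void fill(scratch & s, int n) {
    s.buf.resize(n); // reuses existing capacity when n <= buf.capacity()
    for (int i = 0; i < n; ++i) {
        s.buf[i] = i;
    }
}

int main() {
    scratch s;
    for (int iter = 0; iter < 3; ++iter) {
        fill(s, 1000); // only the first iteration allocates
    }
    std::printf("capacity: %zu\n", s.buf.capacity());
    return 0;
}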
@@ -527,12 +494,56 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
         /* .ctx = */ new llama_sampler_chain {
             /* .params      = */ params,
             /* .samplers    = */ {},
+            /* .cur         = */ {},
             /* .t_sample_us = */ 0,
             /* .n_sample    = */ 0,
         }
     );
 }
 
+llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
+    const auto * logits = llama_get_logits_ith(ctx, idx);
+
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const int n_vocab = llama_vocab_n_tokens(vocab);
+
+    // use pre-allocated buffer from chain if available, otherwise allocate locally
+    std::vector<llama_token_data> * cur_ptr;
+    std::vector<llama_token_data> cur_local;
+
+    if (smpl->iface == &llama_sampler_chain_i) {
+        auto * chain = (llama_sampler_chain *) smpl->ctx;
+        cur_ptr = &chain->cur;
+    } else {
+        cur_ptr = &cur_local;
+    }
+
+    auto & cur = *cur_ptr;
+    cur.resize(n_vocab);
+    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+        cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
+    }
+
+    llama_token_data_array cur_p = {
+        /* .data     = */ cur.data(),
+        /* .size     = */ cur.size(),
+        /* .selected = */ -1,
+        /* .sorted   = */ false,
+    };
+
+    llama_sampler_apply(smpl, &cur_p);
+
+    GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
+
+    auto token = cur_p.data[cur_p.selected].id;
+
+    llama_sampler_accept(smpl, token);
+
+    return token;
+}
+
 void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
     auto * p = (llama_sampler_chain *) chain->ctx;
     p->samplers.push_back(smpl);
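For context, a minimal usage sketch (not part of the diff) of the path that benefits: a sampler chain created once and sampled from repeatedly. All calls below are existing llama.h API; `ctx` and `n_predict` are assumed to come from the surrounding program, and llama_decode is assumed to have run so logits for index -1 are available. Because `chain` is a sampler chain, each llama_sampler_sample call now fills the chain's persistent `cur` buffer instead of constructing a new vector:

#include "llama.h"

// Sketch only: `ctx` must be a valid llama_context with logits available
// at the last position (idx == -1).
static void generate(llama_context * ctx, int n_predict) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_greedy());

    for (int i = 0; i < n_predict; ++i) {
        // reuses chain->cur internally; no per-call allocation after the first
        const llama_token token = llama_sampler_sample(chain, ctx, -1);
        (void) token; // a real loop would decode `token` and check for end of generation
    }

    llama_sampler_free(chain);
}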
@@ -16,6 +16,9 @@ struct llama_sampler_chain {
 
     std::vector<struct llama_sampler *> samplers;
 
+    // pre-allocated buffer for llama_sampler_sample to avoid repeated allocations
+    std::vector<llama_token_data> cur;
+
     // timing
 
     mutable int64_t t_sample_us;
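Design note (not part of the diff): per the second commit-message bullet, `cur` sits after `samplers` and before the `// timing` section. Since the scratch buffer is owned by the chain, concurrent llama_sampler_sample calls on the same chain would race on `cur`; in practice a chain is already per-thread/per-sequence state, as its mutable timing counters suggest.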