sampling : always expose sampled_ids
This commit precomputes and caches the full-vocab token id list in llama_context's constructor, so llama_get_backend_sampled_token_ids_ith always returns a valid pointer. The motivation for this is that this enables both common/sampling.cpp and src/llama-sampling.cpp can simplify their logic. Not all backends samplers that process logits need to set the sampled_tokens_id as they may not change the order of the logits, for example the temperature sampler only scales the logits but does not change their order. Simliar the logit bias sampler only adds bias to specific token ids but does not change the order of the logits. In these cases there will not be a device to host copy of the sampled token ids, and this is the use case where having this precomputed list is useful.
This commit is contained in:
parent
4b52e59903
commit
82957a90f2
|
|
@ -128,28 +128,14 @@ struct common_sampler {
|
|||
if (sampled_probs) {
|
||||
const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx);
|
||||
cur.reserve(sampled_probs_count);
|
||||
// The backend sampler has filtered the probabilities so we need to use the sampled ids.
|
||||
if (sampled_ids != nullptr) {
|
||||
for (uint32_t i = 0; i < sampled_probs_count; ++i) {
|
||||
cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
|
||||
}
|
||||
} else {
|
||||
for (llama_token token_id = 0; token_id < (int) sampled_probs_count; token_id++) {
|
||||
cur.emplace_back(llama_token_data{token_id, 0.0f, sampled_probs[token_id]});
|
||||
}
|
||||
for (uint32_t i = 0; i < sampled_probs_count; ++i) {
|
||||
cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
|
||||
}
|
||||
} else if (sampled_logits) {
|
||||
const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
|
||||
cur.reserve(sampled_logits_count);
|
||||
// The backend sampler has filtered the logits so we need to use the sampled ids.
|
||||
if (sampled_ids != nullptr) {
|
||||
for (uint32_t i = 0; i < sampled_logits_count; i++) {
|
||||
cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
|
||||
}
|
||||
} else {
|
||||
for (llama_token token_id = 0; token_id < (int) sampled_logits_count; token_id++) {
|
||||
cur.emplace_back(llama_token_data{token_id, sampled_logits[token_id], 0.0f});
|
||||
}
|
||||
for (uint32_t i = 0; i < sampled_logits_count; i++) {
|
||||
cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
|
||||
}
|
||||
} else {
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
|
|
|
|||
|
|
@ -430,6 +430,16 @@ llama_context::llama_context(
|
|||
LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize the full vocabulary token ids for backend samplers.
|
||||
{
|
||||
const llama_vocab * vocab = llama_model_get_vocab(&model);
|
||||
const int n_vocab = llama_vocab_n_tokens(vocab);
|
||||
sampled_token_ids_full_vocab.resize(n_vocab);
|
||||
for (int i = 0; i < n_vocab; ++i) {
|
||||
sampled_token_ids_full_vocab[i] = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
llama_context::~llama_context() {
|
||||
|
|
@ -728,15 +738,18 @@ float * llama_context::get_backend_sampled_logits_ith(int32_t idx) {
|
|||
const llama_token * llama_context::get_backend_sampled_token_ids_ith(int32_t idx) {
|
||||
if (idx == -1) {
|
||||
if (sampled_token_ids_map.size() == 1) {
|
||||
return sampled_token_ids_map.begin()->second.data();
|
||||
const auto & vec = sampled_token_ids_map.begin()->second;
|
||||
if (!vec.empty()) {
|
||||
return vec.data();
|
||||
}
|
||||
}
|
||||
}
|
||||
auto it = sampled_token_ids_map.find(idx);
|
||||
if (it == sampled_token_ids_map.end() || it->second.empty()) {
|
||||
return nullptr;
|
||||
if (it != sampled_token_ids_map.end() && !it->second.empty()) {
|
||||
return it->second.data();
|
||||
}
|
||||
|
||||
return it->second.data();
|
||||
return sampled_token_ids_full_vocab.data();
|
||||
}
|
||||
|
||||
size_t llama_context::get_backend_sampled_logits_count(int32_t idx) const {
|
||||
|
|
|
|||
|
|
@ -263,6 +263,7 @@ private:
|
|||
|
||||
std::unordered_map<int32_t, std::vector<float>> sampled_logits_map;
|
||||
std::unordered_map<int32_t, std::vector<llama_token>> sampled_token_ids_map;
|
||||
std::vector<llama_token> sampled_token_ids_full_vocab;
|
||||
|
||||
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
||||
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
||||
|
|
|
|||
|
|
@ -460,28 +460,14 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
|
|||
if (sampled_probs) {
|
||||
const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx);
|
||||
cur.reserve(sampled_probs_count);
|
||||
// The backend sampler has filtered the probabilities so we need to use the sampled ids.
|
||||
if (sampled_ids != nullptr) {
|
||||
for (uint32_t i = 0; i < sampled_probs_count; ++i) {
|
||||
cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
|
||||
}
|
||||
} else {
|
||||
for (llama_token token_id = 0; token_id < (int) sampled_probs_count; token_id++) {
|
||||
cur.emplace_back(llama_token_data{token_id, 0.0f, sampled_probs[token_id]});
|
||||
}
|
||||
for (uint32_t i = 0; i < sampled_probs_count; ++i) {
|
||||
cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
|
||||
}
|
||||
} else if (sampled_logits) {
|
||||
const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
|
||||
cur.reserve(sampled_logits_count);
|
||||
// The backend sampler has filtered the logits so we need to use the sampled ids.
|
||||
if (sampled_ids != nullptr) {
|
||||
for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
|
||||
cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
|
||||
}
|
||||
} else {
|
||||
for (llama_token token_id = 0; token_id < (int)sampled_logits_count; token_id++) {
|
||||
cur.emplace_back(llama_token_data{token_id, sampled_logits[token_id], 0.0f});
|
||||
}
|
||||
for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
|
||||
cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
|
||||
}
|
||||
} else {
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
|
|
|
|||
Loading…
Reference in New Issue