diff --git a/common/sampling.cpp b/common/sampling.cpp
index ec61c18832..9c707a5bb9 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -128,28 +128,14 @@ struct common_sampler {
         if (sampled_probs) {
             const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx);
             cur.reserve(sampled_probs_count);
-            // The backend sampler has filtered the probabilities so we need to use the sampled ids.
-            if (sampled_ids != nullptr) {
-                for (uint32_t i = 0; i < sampled_probs_count; ++i) {
-                    cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
-                }
-            } else {
-                for (llama_token token_id = 0; token_id < (int) sampled_probs_count; token_id++) {
-                    cur.emplace_back(llama_token_data{token_id, 0.0f, sampled_probs[token_id]});
-                }
+            for (uint32_t i = 0; i < sampled_probs_count; ++i) {
+                cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
             }
         } else if (sampled_logits) {
             const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
             cur.reserve(sampled_logits_count);
-            // The backend sampler has filtered the logits so we need to use the sampled ids.
-            if (sampled_ids != nullptr) {
-                for (uint32_t i = 0; i < sampled_logits_count; i++) {
-                    cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
-                }
-            } else {
-                for (llama_token token_id = 0; token_id < (int) sampled_logits_count; token_id++) {
-                    cur.emplace_back(llama_token_data{token_id, sampled_logits[token_id], 0.0f});
-                }
+            for (uint32_t i = 0; i < sampled_logits_count; i++) {
+                cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
             }
         } else {
             const auto * logits = llama_get_logits_ith(ctx, idx);
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index f931881c9c..25d3528434 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -430,6 +430,16 @@ llama_context::llama_context(
             LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
         }
     }
+
+    // Initialize the full vocabulary token ids for backend samplers.
+    {
+        const llama_vocab * vocab = llama_model_get_vocab(&model);
+        const int n_vocab = llama_vocab_n_tokens(vocab);
+        sampled_token_ids_full_vocab.resize(n_vocab);
+        for (int i = 0; i < n_vocab; ++i) {
+            sampled_token_ids_full_vocab[i] = i;
+        }
+    }
 }
 
 llama_context::~llama_context() {
@@ -728,15 +738,18 @@ float * llama_context::get_backend_sampled_logits_ith(int32_t idx) {
 const llama_token * llama_context::get_backend_sampled_token_ids_ith(int32_t idx) {
     if (idx == -1) {
         if (sampled_token_ids_map.size() == 1) {
-            return sampled_token_ids_map.begin()->second.data();
+            const auto & vec = sampled_token_ids_map.begin()->second;
+            if (!vec.empty()) {
+                return vec.data();
+            }
         }
     }
 
     auto it = sampled_token_ids_map.find(idx);
-    if (it == sampled_token_ids_map.end() || it->second.empty()) {
-        return nullptr;
+    if (it != sampled_token_ids_map.end() && !it->second.empty()) {
+        return it->second.data();
     }
-    return it->second.data();
+    return sampled_token_ids_full_vocab.data();
 }
 
 size_t llama_context::get_backend_sampled_logits_count(int32_t idx) const {
diff --git a/src/llama-context.h b/src/llama-context.h
index b9020beff1..aba62e6e38 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -263,6 +263,7 @@ private:
 
     std::unordered_map<int32_t, std::vector<float>> sampled_logits_map;
     std::unordered_map<int32_t, std::vector<llama_token>> sampled_token_ids_map;
+    std::vector<llama_token> sampled_token_ids_full_vocab;
 
     // embeddings output (2-dimensional array: [n_outputs][n_embd])
     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index dc9227c1a5..d210b826c7 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -460,28 +460,14 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
     if (sampled_probs) {
         const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx);
         cur.reserve(sampled_probs_count);
-        // The backend sampler has filtered the probabilities so we need to use the sampled ids.
-        if (sampled_ids != nullptr) {
-            for (uint32_t i = 0; i < sampled_probs_count; ++i) {
-                cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
-            }
-        } else {
-            for (llama_token token_id = 0; token_id < (int) sampled_probs_count; token_id++) {
-                cur.emplace_back(llama_token_data{token_id, 0.0f, sampled_probs[token_id]});
-            }
+        for (uint32_t i = 0; i < sampled_probs_count; ++i) {
+            cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
         }
     } else if (sampled_logits) {
         const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
         cur.reserve(sampled_logits_count);
-        // The backend sampler has filtered the logits so we need to use the sampled ids.
-        if (sampled_ids != nullptr) {
-            for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
-                cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
-            }
-        } else {
-            for (llama_token token_id = 0; token_id < (int)sampled_logits_count; token_id++) {
-                cur.emplace_back(llama_token_data{token_id, sampled_logits[token_id], 0.0f});
-            }
+        for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
+            cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
         }
     } else {
         const auto * logits = llama_get_logits_ith(ctx, idx);
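
Summary sketch (not part of the patch): with the full-vocabulary fallback above, get_backend_sampled_token_ids_ith() always returns a usable id array — either the backend-filtered ids or the identity mapping built in the constructor — so callers can index sampled_ids unconditionally, as the simplified loops in common/sampling.cpp and llama-sampling.cpp now do. A minimal consumer-side sketch under that assumption; build_candidates is a hypothetical helper, and sampled_probs/sampled_ids are assumed to have been fetched via the corresponding *_ith accessors used in the diff:

    // Sketch only: build the candidate list from backend-sampled probabilities.
    // With the fallback in place, `sampled_ids` is never nullptr: it is either the
    // filtered id list or the full-vocabulary identity mapping (sampled_ids[i] == i).
    #include <vector>
    #include "llama.h"

    static std::vector<llama_token_data> build_candidates(
            llama_context * ctx, int32_t idx,
            const float * sampled_probs, const llama_token * sampled_ids) {
        std::vector<llama_token_data> cur;
        const uint32_t n = llama_get_backend_sampled_probs_count_ith(ctx, idx);
        cur.reserve(n);
        for (uint32_t i = 0; i < n; ++i) {
            cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
        }
        return cur;
    }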