diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 2cffa524cd..11679c6c9e 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -459,22 +459,22 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
     if (sampled_probs) {
         const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx);
-        cur.reserve(sampled_probs_count);
+        cur.resize(sampled_probs_count);
         for (uint32_t i = 0; i < sampled_probs_count; ++i) {
-            cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]});
+            cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
         }
     } else if (sampled_logits) {
         const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
-        cur.reserve(sampled_logits_count);
+        cur.resize(sampled_logits_count);
         for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
-            cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
+            cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
         }
     } else {
         const auto * logits = llama_get_logits_ith(ctx, idx);
         GGML_ASSERT(logits != nullptr);
-        cur.reserve(n_vocab);
+        cur.resize(n_vocab);
         for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
         }
     }
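For context, here is a minimal standalone sketch of the pattern change this diff makes in all three branches: `reserve()` only allocates capacity (size stays 0, so elements must still be appended with `emplace_back`), while `resize()` value-initializes the elements up front so they can be written directly by index. The `token_data` struct and the sample arrays below are illustrative stand-ins, not the actual llama.cpp types or data.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative stand-in for llama_token_data (the real struct lives in llama.h).
struct token_data {
    int32_t id;
    float   logit;
    float   p;
};

int main() {
    const uint32_t count    = 4;
    const int32_t  ids[]    = {7, 3, 9, 1};
    const float    logits[] = {2.5f, 1.0f, 0.5f, -1.0f};

    std::vector<token_data> cur;

    // Pattern before the diff: reserve capacity, then append one element at a time.
    cur.reserve(count);
    for (uint32_t i = 0; i < count; ++i) {
        cur.emplace_back(token_data{ids[i], logits[i], 0.0f});
    }
    cur.clear();

    // Pattern after the diff: resize to the final size, then assign by index.
    // resize() value-initializes every element, so cur[i] is valid to write;
    // with reserve() alone, indexing past size() would be undefined behavior.
    cur.resize(count);
    for (uint32_t i = 0; i < count; ++i) {
        cur[i] = token_data{ids[i], logits[i], 0.0f};
    }

    for (const auto & td : cur) {
        std::printf("id=%d logit=%.2f p=%.2f\n", td.id, td.logit, td.p);
    }
    return 0;
}
```

Both patterns produce the same final vector; the diff simply switches to the resize-and-assign form, where the element count is known exactly before the loop runs.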