diff --git a/common/sampling.cpp b/common/sampling.cpp index 9c707a5bb9..9813762eca 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -129,7 +129,7 @@ struct common_sampler { const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx); cur.reserve(sampled_probs_count); for (uint32_t i = 0; i < sampled_probs_count; ++i) { - cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]}); + cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]}); } } else if (sampled_logits) { const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 456e050201..2cffa524cd 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -461,7 +461,7 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx); cur.reserve(sampled_probs_count); for (uint32_t i = 0; i < sampled_probs_count; ++i) { - cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]}); + cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]}); } } else if (sampled_logits) { const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);