squash! common : fix regression caused by extra memory allocations during sampling

Apply the same changes to llama_sampler_sample in llama-sampling.cpp as
were applied in commit 38f408c25: replace reserve() + emplace_back()
with resize() + indexed assignment when filling the candidate buffer.
Daniel Bevenius 2025-11-20 07:56:33 +01:00
parent 0c660e7390
commit ed4345bdd9
1 changed file with 6 additions and 6 deletions


@@ -459,22 +459,22 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
     if (sampled_probs) {
         const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx);
-        cur.reserve(sampled_probs_count);
+        cur.resize(sampled_probs_count);
         for (uint32_t i = 0; i < sampled_probs_count; ++i) {
-            cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]});
+            cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
         }
     } else if (sampled_logits) {
         const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
-        cur.reserve(sampled_logits_count);
+        cur.resize(sampled_logits_count);
         for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
-            cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
+            cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
         }
     } else {
         const auto * logits = llama_get_logits_ith(ctx, idx);
         GGML_ASSERT(logits != nullptr);
-        cur.reserve(n_vocab);
+        cur.resize(n_vocab);
         for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
         }
     }
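
For context, a minimal standalone sketch of the pattern this diff adopts, assuming (as the commit title suggests) that the candidate buffer is reused across sampling calls without being cleared. token_data, fill_resize, and fill_append are hypothetical stand-ins for illustration, not llama.cpp API:

#include <cstdint>
#include <cstdio>
#include <vector>

struct token_data {          // stand-in for llama_token_data
    int32_t id;
    float   logit;
    float   p;
};

// Pattern adopted by the diff: resize() + indexed writes. Previous contents
// of a reused buffer are overwritten in place; once capacity has reached
// n_vocab, subsequent calls allocate nothing.
static void fill_resize(std::vector<token_data> & cur, const float * logits, int32_t n_vocab) {
    cur.resize(n_vocab);
    for (int32_t id = 0; id < n_vocab; id++) {
        cur[id] = token_data{id, logits[id], 0.0f};
    }
}

// Pattern removed by the diff: reserve() + emplace_back(). emplace_back
// appends after any elements already in the buffer, so a reused buffer
// that is not cleared first keeps growing and reallocating.
static void fill_append(std::vector<token_data> & cur, const float * logits, int32_t n_vocab) {
    cur.reserve(n_vocab);
    for (int32_t id = 0; id < n_vocab; id++) {
        cur.emplace_back(token_data{id, logits[id], 0.0f});
    }
}

int main() {
    const float logits[4] = {0.1f, 0.7f, 0.15f, 0.05f};

    std::vector<token_data> a; // reused across "sampling calls"
    std::vector<token_data> b;
    for (int call = 0; call < 3; call++) {
        fill_resize(a, logits, 4);
        fill_append(b, logits, 4);
        printf("call %d: resize size=%zu cap=%zu | append size=%zu cap=%zu\n",
               call, a.size(), a.capacity(), b.size(), b.capacity());
    }
    return 0;
}

Once the resize variant's capacity reaches its high-water mark it allocates nothing on later calls, while the append variant keeps growing (4, then 8, then 12 elements here) and reallocating along the way, which is consistent with the extra allocations the commit title refers to.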