common : fix regression caused by extra memory allocations during sampling
This commit is contained in:
parent d74eb61aa7
commit 38f408c253
@@ -122,27 +122,24 @@ struct common_sampler {
 
         const int n_vocab = llama_vocab_n_tokens(vocab);
 
-        // Use the member variable instead of allocating locally
-        cur.clear();
-
         if (sampled_probs) {
             const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx);
-            cur.reserve(sampled_probs_count);
+            cur.resize(sampled_probs_count);
             for (uint32_t i = 0; i < sampled_probs_count; ++i) {
-                cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]});
+                cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
             }
         } else if (sampled_logits) {
             const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
-            cur.reserve(sampled_logits_count);
+            cur.resize(sampled_logits_count);
             for (uint32_t i = 0; i < sampled_logits_count; i++) {
-                cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
+                cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
             }
         } else {
             const auto * logits = llama_get_logits_ith(ctx, idx);
             GGML_ASSERT(logits != nullptr);
-            cur.reserve(n_vocab);
+            cur.resize(n_vocab);
             for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-                cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+                cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
             }
         }
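
For context, here is a minimal standalone sketch (not part of this commit) of the buffer-reuse pattern the new side of the diff relies on: `cur` is a member vector that persists across sampling calls, so `resize()` only reallocates when the requested size exceeds the current capacity, and each slot is then filled by plain assignment. `sampler_stub` and `llama_token_data_stub` are hypothetical stand-ins for `common_sampler` and `llama_token_data`.

    // Sketch of reusing a persistent member vector instead of rebuilding it
    // each call; std::vector::resize() reallocates only when the requested
    // size exceeds the current capacity.
    #include <cstdio>
    #include <vector>

    struct llama_token_data_stub {   // hypothetical stand-in for llama_token_data
        int   id;
        float logit;
        float p;
    };

    struct sampler_stub {            // hypothetical stand-in for common_sampler
        std::vector<llama_token_data_stub> cur;  // reused across calls

        void set_logits(const float * logits, int n_vocab) {
            cur.resize(n_vocab);     // allocates only if n_vocab > cur.capacity()
            for (int token_id = 0; token_id < n_vocab; token_id++) {
                cur[token_id] = llama_token_data_stub{token_id, logits[token_id], 0.0f};
            }
        }
    };

    int main() {
        sampler_stub smpl;
        std::vector<float> logits(32000, 0.0f);

        for (int step = 0; step < 3; step++) {
            smpl.set_logits(logits.data(), (int) logits.size());
            // capacity is unchanged after the first call: no further allocations
            printf("step %d: size=%zu capacity=%zu\n", step, smpl.cur.size(), smpl.cur.capacity());
        }
        return 0;
    }

Running the sketch prints the same capacity on every step after the first, i.e. the member vector is not reallocated once it has grown to the vocabulary size.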