llama : use output_resolve_row() in get_logits_ith/get_embeddings_ith

This commit updates get_logits_ith(), and get_embeddings_ith() to use
output_resolve_row() to resolve the batch index to output row index.

The motivation for this is to remove some code duplication between these
functions.
This commit is contained in:
Daniel Bevenius 2026-02-16 10:34:15 +01:00
parent d5dfc33027
commit ba16f664f1
No known key found for this signature in database
1 changed files with 2 additions and 44 deletions

View File

@ -712,8 +712,6 @@ int64_t llama_context::output_resolve_row(int32_t i) const {
}
float * llama_context::get_logits_ith(int32_t i) {
int64_t j = -1;
output_reorder();
try {
@ -721,26 +719,7 @@ float * llama_context::get_logits_ith(int32_t i) {
throw std::runtime_error("no logits");
}
// TODO: use output_resolve_row()
if (i < 0) {
j = n_outputs + i;
if (j < 0) {
throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
}
} else if ((size_t) i >= output_ids.size()) {
throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
} else {
j = output_ids[i];
}
if (j < 0) {
throw std::runtime_error(format("batch.logits[%d] != true", i));
}
if (j >= n_outputs) {
// This should not happen
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
}
const int64_t j = output_resolve_row(i);
return logits.data + j*model.vocab.n_tokens();
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
@ -763,8 +742,6 @@ llama_token * llama_context::get_sampled_tokens() const{
}
float * llama_context::get_embeddings_ith(int32_t i) {
int64_t j = -1;
output_reorder();
try {
@ -772,26 +749,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
throw std::runtime_error("no embeddings");
}
// TODO: use output_resolve_row()
if (i < 0) {
j = n_outputs + i;
if (j < 0) {
throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
}
} else if ((size_t) i >= output_ids.size()) {
throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
} else {
j = output_ids[i];
}
if (j < 0) {
throw std::runtime_error(format("batch.logits[%d] != true", i));
}
if (j >= n_outputs) {
// This should not happen
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
}
const int64_t j = output_resolve_row(i);
const uint32_t n_embd_out = model.hparams.n_embd_out();
return embd.data + j*n_embd_out;
} catch (const std::exception & err) {