diff --git a/src/llama-context.cpp b/src/llama-context.cpp index fc05989aa5..828709486a 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -712,8 +712,6 @@ int64_t llama_context::output_resolve_row(int32_t i) const { } float * llama_context::get_logits_ith(int32_t i) { - int64_t j = -1; - output_reorder(); try { @@ -721,26 +719,7 @@ float * llama_context::get_logits_ith(int32_t i) { throw std::runtime_error("no logits"); } - // TODO: use output_resolve_row() - if (i < 0) { - j = n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); - } - } else if ((size_t) i >= output_ids.size()) { - throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); - } else { - j = output_ids[i]; - } - - if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); - } - + const int64_t j = output_resolve_row(i); return logits.data + j*model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); @@ -763,8 +742,6 @@ llama_token * llama_context::get_sampled_tokens() const{ } float * llama_context::get_embeddings_ith(int32_t i) { - int64_t j = -1; - output_reorder(); try { @@ -772,26 +749,7 @@ float * llama_context::get_embeddings_ith(int32_t i) { throw std::runtime_error("no embeddings"); } - // TODO: use output_resolve_row() - if (i < 0) { - j = n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); - } - } else if ((size_t) i >= output_ids.size()) { - throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); - } else { - j = output_ids[i]; - } - - if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); - } - + const int64_t j = output_resolve_row(i); const uint32_t n_embd_out = model.hparams.n_embd_out(); return embd.data + j*n_embd_out; } catch (const std::exception & err) {