Merge ba16f664f1 into 05fa625eac
This commit is contained in:
commit
2afbf724d5
|
|
@ -712,8 +712,6 @@ int64_t llama_context::output_resolve_row(int32_t i) const {
|
|||
}
|
||||
|
||||
float * llama_context::get_logits_ith(int32_t i) {
|
||||
int64_t j = -1;
|
||||
|
||||
output_reorder();
|
||||
|
||||
try {
|
||||
|
|
@ -721,26 +719,7 @@ float * llama_context::get_logits_ith(int32_t i) {
|
|||
throw std::runtime_error("no logits");
|
||||
}
|
||||
|
||||
// TODO: use output_resolve_row()
|
||||
if (i < 0) {
|
||||
j = n_outputs + i;
|
||||
if (j < 0) {
|
||||
throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
|
||||
}
|
||||
} else if ((size_t) i >= output_ids.size()) {
|
||||
throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
|
||||
} else {
|
||||
j = output_ids[i];
|
||||
}
|
||||
|
||||
if (j < 0) {
|
||||
throw std::runtime_error(format("batch.logits[%d] != true", i));
|
||||
}
|
||||
if (j >= n_outputs) {
|
||||
// This should not happen
|
||||
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
|
||||
}
|
||||
|
||||
const int64_t j = output_resolve_row(i);
|
||||
return logits.data + j*model.vocab.n_tokens();
|
||||
} catch (const std::exception & err) {
|
||||
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
|
||||
|
|
@ -763,8 +742,6 @@ llama_token * llama_context::get_sampled_tokens() const{
|
|||
}
|
||||
|
||||
float * llama_context::get_embeddings_ith(int32_t i) {
|
||||
int64_t j = -1;
|
||||
|
||||
output_reorder();
|
||||
|
||||
try {
|
||||
|
|
@ -772,26 +749,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
|
|||
throw std::runtime_error("no embeddings");
|
||||
}
|
||||
|
||||
// TODO: use output_resolve_row()
|
||||
if (i < 0) {
|
||||
j = n_outputs + i;
|
||||
if (j < 0) {
|
||||
throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs));
|
||||
}
|
||||
} else if ((size_t) i >= output_ids.size()) {
|
||||
throw std::runtime_error(format("out of range [0, %zu)", output_ids.size()));
|
||||
} else {
|
||||
j = output_ids[i];
|
||||
}
|
||||
|
||||
if (j < 0) {
|
||||
throw std::runtime_error(format("batch.logits[%d] != true", i));
|
||||
}
|
||||
if (j >= n_outputs) {
|
||||
// This should not happen
|
||||
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
|
||||
}
|
||||
|
||||
const int64_t j = output_resolve_row(i);
|
||||
const uint32_t n_embd_out = model.hparams.n_embd_out();
|
||||
return embd.data + j*n_embd_out;
|
||||
} catch (const std::exception & err) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue