Merge ac6d09c63c into 06bf3796f4
commit c72416f496
@@ -757,7 +757,7 @@ float * llama_context::get_embeddings() {
 }
 
 llama_token * llama_context::get_sampled_tokens() const{
-    return sampling.sampled;
+    return sampling.sampled.data;
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
@@ -814,14 +814,14 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
 llama_token llama_context::get_sampled_token_ith(int32_t idx) {
     output_reorder();
 
-    if (sampling.sampled == nullptr) {
+    if (!sampling.sampled.has_data()) {
         return LLAMA_TOKEN_NULL;
     }
 
     try {
         const int64_t row = output_resolve_row(idx);
-        GGML_ASSERT(row < (int64_t) sampling.sampled_size);
-        return sampling.sampled[row];
+        GGML_ASSERT(row < (int64_t) sampling.sampled.size);
+        return sampling.sampled.data[row];
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what());
         return LLAMA_TOKEN_NULL;
@@ -831,7 +831,7 @@ llama_token llama_context::get_sampled_token_ith(int32_t idx) {
 float * llama_context::get_sampled_probs_ith(int32_t idx) {
     output_reorder();
 
-    if (sampling.probs == nullptr) {
+    if (!sampling.probs.has_data()) {
         return nullptr;
     }
 
@@ -840,7 +840,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) {
         if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) {
             return nullptr;
         }
-        return sampling.probs + row*model.vocab.n_tokens();
+        return sampling.probs.data + row*model.vocab.n_tokens();
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what());
         return nullptr;
@@ -850,7 +850,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) {
 float * llama_context::get_sampled_logits_ith(int32_t idx) {
     output_reorder();
 
-    if (sampling.logits == nullptr) {
+    if (!sampling.logits.has_data()) {
         return nullptr;
     }
 
@@ -859,7 +859,7 @@ float * llama_context::get_sampled_logits_ith(int32_t idx) {
         if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) {
            return nullptr;
         }
-        return sampling.logits + row*model.vocab.n_tokens();
+        return sampling.logits.data + row*model.vocab.n_tokens();
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what());
         return nullptr;
@@ -871,10 +871,10 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
 
     try {
         const int64_t row = output_resolve_row(idx);
-        if (sampling.candidates != nullptr &&
+        if (sampling.candidates.has_data() &&
             (size_t) row < sampling.candidates_count.size() &&
             sampling.candidates_count[row] > 0) {
-            return sampling.candidates + row*model.vocab.n_tokens();
+            return sampling.candidates.data + row*model.vocab.n_tokens();
         }
     } catch (const std::exception & err) {
         // fallback to full vocab list
@@ -886,7 +886,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
 size_t llama_context::get_sampled_candidates_count(int32_t idx) {
     output_reorder();
 
-    if (sampling.candidates == nullptr) {
+    if (!sampling.candidates.has_data()) {
         return 0;
     }
 
@@ -905,7 +905,7 @@ size_t llama_context::get_sampled_candidates_count(int32_t idx) {
 size_t llama_context::get_sampled_logits_count(int32_t idx) {
     output_reorder();
 
-    if (sampling.logits == nullptr) {
+    if (!sampling.logits.has_data()) {
         return model.vocab.n_tokens();
     }
 
@@ -924,7 +924,7 @@ size_t llama_context::get_sampled_logits_count(int32_t idx) {
 size_t llama_context::get_sampled_probs_count(int32_t idx) {
     output_reorder();
 
-    if (sampling.probs == nullptr) {
+    if (!sampling.probs.has_data()) {
         return 0;
     }
 
@@ -1363,11 +1363,10 @@ static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubat
 
 static void copy_tensor_async_ints(
         const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
-        llama_token * sampled,
-        size_t sampled_size,
+        const buffer_view<llama_token> & sampled,
         const std::map<llama_seq_id, uint32_t> & seq_to_row,
         ggml_backend_sched_t sched) {
-    if (sampled == nullptr) {
+    if (!sampled.has_data()) {
         return;
     }
 
@@ -1378,23 +1377,23 @@ static void copy_tensor_async_ints(
         }
 
         const uint32_t row = it->second;
-        GGML_ASSERT(row < sampled_size);
+        GGML_ASSERT(row < sampled.size);
 
         GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
 
         ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
-        ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row]));
+        ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row]));
     }
 }
 
 static void copy_tensor_async_floats(
         const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
-        float * dst,
+        const buffer_view<float> & dst,
         size_t stride,
         std::vector<uint32_t> & counts,
         const std::map<llama_seq_id, uint32_t> & seq_to_row,
         ggml_backend_sched_t sched) {
-    if (dst == nullptr) {
+    if (!dst.has_data()) {
         return;
     }
 
@@ -1410,7 +1409,7 @@ static void copy_tensor_async_floats(
         GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
 
         ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
-        float * row_ptr = dst + (size_t) row * stride;
+        float * row_ptr = dst.data + (size_t) row * stride;
         ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
 
         // Update the actual number of logits/probabilities that were written for this row.
@@ -1420,12 +1419,12 @@ static void copy_tensor_async_floats(
 
 static void copy_tensor_async_candidates(
         const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
-        llama_token * dst,
+        const buffer_view<llama_token> & dst,
         size_t stride,
         std::vector<uint32_t> & counts,
         const std::map<llama_seq_id, uint32_t> & seq_to_row,
         ggml_backend_sched_t sched) {
-    if (dst == nullptr) {
+    if (!dst.has_data()) {
         return;
     }
 
@@ -1441,7 +1440,7 @@ static void copy_tensor_async_candidates(
         GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
 
         ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
-        llama_token * row_ptr = dst + (size_t) row * stride;
+        llama_token * row_ptr = dst.data + (size_t) row * stride;
         ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
 
         // Update the actual number of candidates that were written.
@@ -1747,7 +1746,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         const auto stride = n_vocab;
 
         // async copy the sampling data from the backend to the host
-        copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
+        copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get());
 
         copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
         copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
@@ -1847,13 +1846,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     // Allocate backend sampling output buffers if there are backend samplers configured.
     const bool has_sampling = !sampling.samplers.empty();
     if (has_sampling) {
-        sampling.logits_size = n_vocab*n_outputs_max;
-        sampling.probs_size = n_vocab*n_outputs_max;
-        sampling.sampled_size = n_outputs_max;
-        sampling.candidates_size = n_vocab*n_outputs_max;
-
-        backend_float_count = sampling.logits_size + sampling.probs_size;
-        backend_token_count = sampling.sampled_size + sampling.candidates_size;
+        backend_float_count = 2 * n_vocab * n_outputs_max; // logits + probs
+        backend_token_count = (1 + n_vocab) * n_outputs_max; // sampled + candidates
     }
 
     if (output_ids.empty()) {
@@ -1910,23 +1904,23 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     embd = has_embd ? (float *) (base + offset) : nullptr;
     offset += embd_size * sizeof(float);
 
-    sampling.logits = nullptr;
-    sampling.probs = nullptr;
-    sampling.sampled = nullptr;
-    sampling.candidates = nullptr;
+    sampling.logits = {nullptr, 0};
+    sampling.probs = {nullptr, 0};
+    sampling.sampled = {nullptr, 0};
+    sampling.candidates = {nullptr, 0};
 
     if (has_sampling) {
-        sampling.logits = (float *) (base + offset);
-        offset += sampling.logits_size * sizeof(float);
+        sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
+        offset += sampling.logits.size * sizeof(float);
 
-        sampling.probs = (float *) (base + offset);
-        offset += sampling.probs_size * sizeof(float);
+        sampling.probs = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
+        offset += sampling.probs.size * sizeof(float);
 
-        sampling.sampled = (llama_token *) (base + offset);
-        offset += sampling.sampled_size * sizeof(llama_token);
+        sampling.sampled = {(llama_token *) (base + offset), (size_t)n_outputs_max};
+        offset += sampling.sampled.size * sizeof(llama_token);
 
-        sampling.candidates = (llama_token *) (base + offset);
-        offset += sampling.candidates_size * sizeof(llama_token);
+        sampling.candidates = {(llama_token *) (base + offset), (size_t)(n_vocab*n_outputs_max)};
+        offset += sampling.candidates.size * sizeof(llama_token);
 
         // The count vectors keep track of the actual number of logits/probs/candidates
         // copied from the backend for each output row.
@@ -1939,7 +1933,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
         std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
         std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
 
-        std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
+        std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL);
     }
 
     // set all ids as invalid (negative)
@@ -1970,26 +1964,26 @@ void llama_context::output_reorder() {
             }
         }
 
-        if (sampling.logits && sampling.logits_size > 0) {
+        if (sampling.logits.has_data()) {
             for (uint64_t k = 0; k < n_vocab; ++k) {
-                std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]);
+                std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]);
             }
         }
 
-        if (sampling.probs && sampling.probs_size > 0) {
+        if (sampling.probs.has_data()) {
             for (uint64_t k = 0; k < n_vocab; ++k) {
-                std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]);
+                std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]);
             }
         }
 
-        if (sampling.candidates && sampling.candidates_size > 0) {
+        if (sampling.candidates.has_data()) {
             for (uint64_t k = 0; k < n_vocab; ++k) {
-                std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]);
+                std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]);
             }
         }
 
-        if (sampling.sampled && sampling.sampled_size > 0) {
-            std::swap(sampling.sampled[i0], sampling.sampled[i1]);
+        if (sampling.sampled.has_data()) {
+            std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]);
         }
 
         if (!sampling.logits_count.empty()) {
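The getter changes above all follow one pattern: guard with has_data(), resolve the output row, then index through .data with a row stride of the vocab size. Below is a minimal standalone sketch of that pattern; the buffer_view body is taken from the patch, while n_vocab_demo, n_rows, and the backing std::vector are illustrative stand-ins rather than part of the change. The header-side declarations follow after the sketch.

#include <cstddef>
#include <cstdio>
#include <vector>

// From the patch: a non-owning pointer plus element count.
template <typename T>
struct buffer_view {
    T * data;
    size_t size = 0;

    bool has_data() const {
        return data && size > 0;
    }
};

int main() {
    const size_t n_vocab_demo = 4; // illustrative vocab size
    const size_t n_rows       = 2; // illustrative number of output rows

    // Backing storage; in llama.cpp this lives in the context's output buffer.
    std::vector<float> probs_storage(n_rows * n_vocab_demo, 0.0f);
    probs_storage[1 * n_vocab_demo + 2] = 0.75f;

    buffer_view<float> probs = { probs_storage.data(), probs_storage.size() };

    // Same access pattern as get_sampled_probs_ith(): guard, then row-stride indexing.
    const size_t row = 1;
    if (probs.has_data() && (row + 1) * n_vocab_demo <= probs.size) {
        const float * row_ptr = probs.data + row * n_vocab_demo;
        printf("probs[row=%zu][2] = %.2f\n", row, row_ptr[2]);
    }

    return 0;
}

Because the element count travels with the pointer, the bounds checks in the getters can compare against .size instead of a separately tracked *_size field.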
@@ -4,6 +4,7 @@
 #include "llama-cparams.h"
 #include "llama-graph.h"
 #include "llama-adapter.h"
+#include "llama-impl.h"
 
 #include "ggml-cpp.h"
 #include "ggml-opt.h"
@@ -277,21 +278,13 @@ private:
     size_t embd_size = 0; // capacity (of floats) for embeddings
     float * embd = nullptr;
 
     // TODO: simplify
     struct sampling_info {
         std::map<llama_seq_id, llama_sampler *> samplers;
-
-        float * logits = nullptr;
-        size_t logits_size = 0;
-
-        llama_token * sampled = nullptr;
-        size_t sampled_size = 0;
-
-        float * probs = nullptr;
-        size_t probs_size = 0;
-
-        llama_token * candidates = nullptr;
-        size_t candidates_size = 0;
+        struct buffer_view<float> logits = {nullptr, 0};
+        struct buffer_view<llama_token> sampled = {nullptr, 0};
+        struct buffer_view<float> probs = {nullptr, 0};
+        struct buffer_view<llama_token> candidates = {nullptr, 0};
 
         std::vector<uint32_t> logits_count;
         std::vector<uint32_t> probs_count;
@@ -49,6 +49,16 @@ struct time_meas {
     int64_t & t_acc;
 };
 
+template <typename T>
+struct buffer_view {
+    T * data;
+    size_t size = 0;
+
+    bool has_data() const {
+        return data && size > 0;
+    }
+};
+
 void replace_all(std::string & s, const std::string & search, const std::string & replace);
 
 // TODO: rename to llama_format ?
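For context on how these views get populated, here is a standalone sketch of the carving done in output_reserve(): one flat byte buffer is split into typed buffer_view slices by advancing an offset, after which std::fill_n can use the stored size directly. The names and sizes (llama_token_demo, n_vocab_demo, n_outputs_demo) are illustrative assumptions, not values taken from the patch.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

typedef int32_t llama_token_demo;       // stand-in for llama_token
const llama_token_demo TOKEN_NULL = -1; // stand-in for LLAMA_TOKEN_NULL

// From the patch: a non-owning pointer plus element count.
template <typename T>
struct buffer_view {
    T * data;
    size_t size = 0;

    bool has_data() const {
        return data && size > 0;
    }
};

int main() {
    const size_t n_vocab_demo   = 8; // illustrative
    const size_t n_outputs_demo = 3; // illustrative

    // One backing allocation, as with the context's output buffer.
    const size_t float_count = 2 * n_vocab_demo * n_outputs_demo;   // logits + probs
    const size_t token_count = (1 + n_vocab_demo) * n_outputs_demo; // sampled + candidates
    std::vector<uint8_t> base(float_count * sizeof(float) + token_count * sizeof(llama_token_demo));

    size_t offset = 0;

    // Carve typed views out of the flat buffer by advancing the offset,
    // mirroring the pattern used in output_reserve().
    buffer_view<float> logits = { (float *) (base.data() + offset), n_vocab_demo * n_outputs_demo };
    offset += logits.size * sizeof(float);

    buffer_view<float> probs = { (float *) (base.data() + offset), n_vocab_demo * n_outputs_demo };
    offset += probs.size * sizeof(float);

    buffer_view<llama_token_demo> sampled = { (llama_token_demo *) (base.data() + offset), n_outputs_demo };
    offset += sampled.size * sizeof(llama_token_demo);

    buffer_view<llama_token_demo> candidates = { (llama_token_demo *) (base.data() + offset), n_vocab_demo * n_outputs_demo };
    offset += candidates.size * sizeof(llama_token_demo);

    // The size travels with the pointer, so initialization no longer needs a separate *_size field.
    std::fill_n(sampled.data, sampled.size, TOKEN_NULL);

    printf("carved %zu of %zu bytes, sampled[0] = %d\n", offset, base.size(), sampled.data[0]);
    return 0;
}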