sampling : support multiple outputs per sequence
This commit adds support for multiple outputs per sequence in the backend sampling implementation. The main motivation for this change is to be able to support speculative decoding using backend samplers where multiple outputs for the same sequence would be needed.
This commit is contained in:
parent
a8b192b6ec
commit
1138d5c2d9
|
|
@ -1296,8 +1296,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
|
||||
std::map<llama_seq_id, uint32_t> seq_to_row;
|
||||
static std::map<llama_seq_id, std::vector<uint32_t>> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
|
||||
std::map<llama_seq_id, std::vector<uint32_t>> seq_to_row;
|
||||
// how many output tokens we have seen so far for this ubatch.
|
||||
uint32_t local = 0;
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||
|
|
@ -1308,96 +1308,114 @@ static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubat
|
|||
|
||||
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
||||
// row_offset is the number of output tokens before this ubatch.
|
||||
seq_to_row[seq_id] = row_offset + local;
|
||||
seq_to_row[seq_id].push_back(row_offset + local);
|
||||
++local;
|
||||
}
|
||||
return seq_to_row;
|
||||
}
|
||||
|
||||
static void copy_tensor_async_ints(
|
||||
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
||||
const std::map<llama_seq_id, std::vector<ggml_tensor*>> & tensor_map,
|
||||
const buffer_view<llama_token> & sampled,
|
||||
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
||||
const std::map<llama_seq_id, std::vector<uint32_t>> & seq_to_row,
|
||||
ggml_backend_sched_t sched) {
|
||||
if (!sampled.has_data()) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const auto & [seq_id, tensor] : tensor_map) {
|
||||
for (const auto & [seq_id, tensors] : tensor_map) {
|
||||
auto it = seq_to_row.find(seq_id);
|
||||
if (it == seq_to_row.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint32_t row = it->second;
|
||||
GGML_ASSERT(row < sampled.size);
|
||||
const std::vector<uint32_t> & rows = it->second;
|
||||
GGML_ASSERT(tensors.size() == rows.size() && "number of tensors must match number of output rows");
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
const uint32_t row = rows[i];
|
||||
ggml_tensor * tensor = tensors[i];
|
||||
|
||||
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
||||
ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row]));
|
||||
GGML_ASSERT(row < sampled.size);
|
||||
GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy");
|
||||
|
||||
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
||||
ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void copy_tensor_async_floats(
|
||||
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
||||
const std::map<llama_seq_id, std::vector<ggml_tensor*>> & tensor_map,
|
||||
const buffer_view<float> & dst,
|
||||
size_t stride,
|
||||
std::vector<uint32_t> & counts,
|
||||
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
||||
const std::map<llama_seq_id, std::vector<uint32_t>> & seq_to_row,
|
||||
ggml_backend_sched_t sched) {
|
||||
if (!dst.has_data()) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const auto & [seq_id, tensor] : tensor_map) {
|
||||
for (const auto & [seq_id, tensors] : tensor_map) {
|
||||
auto it = seq_to_row.find(seq_id);
|
||||
if (it == seq_to_row.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint32_t row = it->second;
|
||||
GGML_ASSERT(row < counts.size());
|
||||
const std::vector<uint32_t> & rows = it->second;
|
||||
GGML_ASSERT(tensors.size() == rows.size() && "number of tensors must match number of output rows");
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
const uint32_t row = rows[i];
|
||||
ggml_tensor * tensor = tensors[i];
|
||||
|
||||
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
||||
float * row_ptr = dst.data + (size_t) row * stride;
|
||||
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
|
||||
GGML_ASSERT(row < counts.size());
|
||||
GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy");
|
||||
|
||||
// Update the actual number of logits/probabilities that were written for this row.
|
||||
counts[row] = ggml_nelements(tensor);
|
||||
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
||||
float * row_ptr = dst.data + (size_t) row * stride;
|
||||
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
|
||||
|
||||
// Update the actual number of logits/probabilities that were written for this row.
|
||||
counts[row] = ggml_nelements(tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void copy_tensor_async_candidates(
|
||||
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
||||
const std::map<llama_seq_id, std::vector<ggml_tensor*>> & tensor_map,
|
||||
const buffer_view<llama_token> & dst,
|
||||
size_t stride,
|
||||
std::vector<uint32_t> & counts,
|
||||
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
||||
const std::map<llama_seq_id, std::vector<uint32_t>> & seq_to_row,
|
||||
ggml_backend_sched_t sched) {
|
||||
if (!dst.has_data()) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const auto & [seq_id, tensor] : tensor_map) {
|
||||
for (const auto & [seq_id, tensors] : tensor_map) {
|
||||
auto it = seq_to_row.find(seq_id);
|
||||
if (it == seq_to_row.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint32_t row = it->second;
|
||||
GGML_ASSERT(row < counts.size());
|
||||
const std::vector<uint32_t> & rows = it->second;
|
||||
GGML_ASSERT(tensors.size() == rows.size() && "number of tensors must match number of output rows");
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
const uint32_t row = rows[i];
|
||||
ggml_tensor * tensor = tensors[i];
|
||||
|
||||
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
||||
llama_token * row_ptr = dst.data + (size_t) row * stride;
|
||||
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
|
||||
GGML_ASSERT(row < counts.size());
|
||||
GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy");
|
||||
|
||||
// Update the actual number of candidates that were written.
|
||||
counts[row] = ggml_nelements(tensor);
|
||||
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
|
||||
llama_token * row_ptr = dst.data + (size_t) row * stride;
|
||||
ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor));
|
||||
|
||||
// Update the actual number of candidates that were written.
|
||||
counts[row] = ggml_nelements(tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1443,30 +1461,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
|
||||
const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max;
|
||||
|
||||
// TODO: avoid this workaround in the future
|
||||
if (has_samplers && batch_inp.logits) {
|
||||
std::vector<int32_t> seq_output_count(n_seq_max, 0);
|
||||
|
||||
for (int32_t i = 0; i < batch_inp.n_tokens; ++i) {
|
||||
if (batch_inp.logits[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1;
|
||||
|
||||
for (int32_t s = 0; s < ns; ++s) {
|
||||
const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0;
|
||||
|
||||
seq_output_count[seq_id]++;
|
||||
if (seq_output_count[seq_id] > 1) {
|
||||
LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n",
|
||||
__func__, seq_id, seq_output_count[seq_id]);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return -1;
|
||||
|
|
|
|||
|
|
@ -774,24 +774,24 @@ void llm_graph_result::set_outputs() {
|
|||
if (t_embd_pooled != nullptr) {
|
||||
ggml_set_output(t_embd_pooled);
|
||||
}
|
||||
for (auto & [seq_id, t] : t_sampled) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
for (auto & [seq_id, tensors] : t_sampled) {
|
||||
for (ggml_tensor * tensor : tensors) {
|
||||
ggml_set_output(tensor);
|
||||
}
|
||||
}
|
||||
for (auto & [seq_id, t] : t_sampled_probs) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
for (auto & [seq_id, tensors] : t_sampled_probs) {
|
||||
for (ggml_tensor * tensor : tensors) {
|
||||
ggml_set_output(tensor);
|
||||
}
|
||||
}
|
||||
for (auto & [seq_id, t] : t_sampled_logits) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
for (auto & [seq_id, tensors] : t_sampled_logits) {
|
||||
for (ggml_tensor * tensor : tensors) {
|
||||
ggml_set_output(tensor);
|
||||
}
|
||||
}
|
||||
for (auto & [seq_id, t] : t_candidates) {
|
||||
if (t != nullptr) {
|
||||
ggml_set_output(t);
|
||||
for (auto & [seq_id, tensors] : t_candidates) {
|
||||
for (ggml_tensor * tensor : tensors) {
|
||||
ggml_set_output(tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2580,13 +2580,13 @@ void llm_graph_context::build_sampling() const {
|
|||
auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
|
||||
res->add_input(std::move(inp_sampling));
|
||||
|
||||
std::map<llama_seq_id, int32_t> seq_to_logit_row;
|
||||
std::map<llama_seq_id, std::vector<int32_t>> seq_to_logit_rows;
|
||||
int32_t logit_row_idx = 0;
|
||||
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
||||
if (ubatch.output[i]) {
|
||||
llama_seq_id seq_id = ubatch.seq_id[i][0];
|
||||
seq_to_logit_row[seq_id] = logit_row_idx;
|
||||
seq_to_logit_rows[seq_id].push_back(logit_row_idx);
|
||||
logit_row_idx++;
|
||||
}
|
||||
}
|
||||
|
|
@ -2600,47 +2600,52 @@ void llm_graph_context::build_sampling() const {
|
|||
ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
|
||||
|
||||
for (const auto & [seq_id, sampler] : samplers) {
|
||||
const auto it = seq_to_logit_row.find(seq_id);
|
||||
const auto row_it = seq_to_logit_rows.find(seq_id);
|
||||
|
||||
// inactive samplers always work on the first row
|
||||
const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0;
|
||||
const int i_out = it != seq_to_logit_row.end() ? 1 : 0;
|
||||
// row_it is now a sequence id to list of row ids
|
||||
static const std::vector<int32_t> default_row = {0};
|
||||
const std::vector<int32_t> & logit_rows = row_it != seq_to_logit_rows.end() ? row_it->second : default_row;
|
||||
for (const int32_t row_idx : logit_rows) {
|
||||
|
||||
ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
|
||||
ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
|
||||
// inactive samplers always work on the first row
|
||||
const int i_out = row_it != seq_to_logit_rows.end() ? 1 : 0;
|
||||
|
||||
struct llama_sampler_data data = {
|
||||
/*.logits =*/ logits_seq,
|
||||
/*.probs =*/ nullptr,
|
||||
/*.sampled =*/ nullptr,
|
||||
/*.candidates =*/ nullptr,
|
||||
};
|
||||
ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
|
||||
ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
|
||||
|
||||
assert(sampler->iface->backend_apply);
|
||||
sampler->iface->backend_apply(sampler, ctx0, gf, &data);
|
||||
struct llama_sampler_data data = {
|
||||
/*.logits =*/ logits_seq,
|
||||
/*.probs =*/ nullptr,
|
||||
/*.sampled =*/ nullptr,
|
||||
/*.candidates =*/ nullptr,
|
||||
};
|
||||
|
||||
if (data.sampled != nullptr) {
|
||||
res->t_sampled[seq_id] = data.sampled;
|
||||
outs[1] = data.sampled;
|
||||
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
||||
}
|
||||
assert(sampler->iface->backend_apply);
|
||||
sampler->iface->backend_apply(sampler, ctx0, gf, &data);
|
||||
|
||||
if (data.probs != nullptr) {
|
||||
res->t_sampled_probs[seq_id] = data.probs;
|
||||
outs[1] = data.probs;
|
||||
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
||||
}
|
||||
if (data.sampled != nullptr) {
|
||||
res->t_sampled[seq_id].push_back(data.sampled);
|
||||
outs[1] = data.sampled;
|
||||
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
||||
}
|
||||
|
||||
if (data.logits != nullptr) {
|
||||
res->t_sampled_logits[seq_id] = data.logits;
|
||||
outs[1] = data.logits;
|
||||
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
||||
}
|
||||
if (data.probs != nullptr) {
|
||||
res->t_sampled_probs[seq_id].push_back(data.probs);
|
||||
outs[1] = data.probs;
|
||||
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
||||
}
|
||||
|
||||
if (data.candidates != nullptr) {
|
||||
res->t_candidates[seq_id] = data.candidates;
|
||||
outs[1] = data.candidates;
|
||||
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
||||
if (data.logits != nullptr) {
|
||||
res->t_sampled_logits[seq_id].push_back(data.logits);
|
||||
outs[1] = data.logits;
|
||||
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
||||
}
|
||||
|
||||
if (data.candidates != nullptr) {
|
||||
res->t_candidates[seq_id].push_back(data.candidates);
|
||||
outs[1] = data.candidates;
|
||||
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -662,10 +662,10 @@ public:
|
|||
ggml_tensor * t_embd = nullptr;
|
||||
ggml_tensor * t_embd_pooled = nullptr;
|
||||
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_candidates;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled_probs;
|
||||
std::map<llama_seq_id, std::vector<ggml_tensor*>> t_sampled_logits;
|
||||
std::map<llama_seq_id, std::vector<ggml_tensor*>> t_candidates;
|
||||
std::map<llama_seq_id, std::vector<ggml_tensor*>> t_sampled;
|
||||
std::map<llama_seq_id, std::vector<ggml_tensor*>> t_sampled_probs;
|
||||
|
||||
std::vector<llm_graph_input_ptr> inputs;
|
||||
|
||||
|
|
|
|||
|
|
@ -968,7 +968,7 @@ static void test_backend_cpu_mixed_batch(const test_params & params) {
|
|||
printf("backend-cpu mixed batch test PASSED\n");
|
||||
}
|
||||
|
||||
static void test_backend_max_outputs(const test_params & params) {
|
||||
static void test_backend_multiple_outputs(const test_params & params) {
|
||||
const int seq_id = 0;
|
||||
const int32_t seed = 88;
|
||||
|
||||
|
|
@ -994,17 +994,32 @@ static void test_backend_max_outputs(const test_params & params) {
|
|||
}
|
||||
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
// set all tokens as output to trigger error
|
||||
// set all tokens as output to get multiple outputs for a single sequence.
|
||||
common_batch_add(batch, tokens[i], i, { seq_id }, true);
|
||||
}
|
||||
|
||||
printf(">>> test_max_outputs expected error start:\n");
|
||||
const int ret = llama_decode(test_ctx.ctx.get(), batch);
|
||||
GGML_ASSERT(ret != 0 && "llama_decode should not succeed multiple outputs per sequence");
|
||||
printf("<<< test_max_outputs expected error end.\n");
|
||||
if (ret != 0) {
|
||||
GGML_ASSERT(false && "Failed to decode sequence with multiple outputs");
|
||||
}
|
||||
|
||||
std::vector<llama_token> sampled_tokens;
|
||||
for (int i = 0; i < batch.n_tokens; i++) {
|
||||
if (batch.logits[i]) {
|
||||
llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), i);
|
||||
const std::string token_str = test_ctx.token_to_piece(token, false);
|
||||
//printf("Position %d: token id=%d, string='%s'\n", i, token, token_str.c_str());
|
||||
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
|
||||
sampled_tokens.push_back(token);
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT((int)sampled_tokens.size() == batch.n_tokens);
|
||||
printf("Sampled %zu tokens for sequence %d\n", sampled_tokens.size(), seq_id);
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
printf("backend max outputs test PASSED\n");
|
||||
printf("backend multiple outputs test PASSED\n");
|
||||
}
|
||||
|
||||
struct backend_test_case {
|
||||
|
|
@ -1023,7 +1038,7 @@ static const backend_test_case BACKEND_TESTS[] = {
|
|||
{ "dist", test_backend_dist_sampling, true },
|
||||
{ "dist_and_cpu", test_backend_dist_sampling_and_cpu, true },
|
||||
{ "set_sampler", test_backend_set_sampler, true },
|
||||
{ "max_outputs", test_backend_max_outputs, true },
|
||||
{ "multiple_outputs",test_backend_multiple_outputs, true },
|
||||
{ "mixed", test_backend_mixed_sampling, true },
|
||||
{ "min_p", test_backend_min_p_sampling, true },
|
||||
{ "cpu_mixed", test_backend_cpu_mixed_batch, true },
|
||||
|
|
|
|||
Loading…
Reference in New Issue