sampling : simplify backend sampling logic in decode
This commit simplifies the backend sampling logic in llama_context::decode by replacing the duplicated per-sequence copy loops with a shared seq_id -> batch-index map (build_seq_to_batch_idx) and three small async-copy helpers.
parent 51fee29822
commit 7e98ebcc6b
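In outline, the diff below replaces four near-identical device-to-host copy loops in llama_context::decode, each of which rebuilt its own seq_id-to-batch-index map and called ggml_backend_tensor_get_async inline, with one shared map builder and three small copy helpers. A condensed sketch of the resulting flow, for orientation only (every identifier is taken from the diff; the literal code follows):

    // sketch of the new decode flow -- see the diff for the real code
    if (has_backend_samplers) {
        const auto seq_to_batch_idx = build_seq_to_batch_idx(ubatch);
        if (!res->t_sampled_tokens.empty()) {
            // a backend sampler already picked tokens: copy only those
            copy_tensor_async_int(res->t_sampled_tokens, sampled_tokens_map, seq_to_batch_idx, sched.get());
        } else {
            // copy the filtered logits/probs so sampling can finish on the CPU
            copy_tensor_async_floats(res->t_sampled_logits, sampled_logits_map, seq_to_batch_idx, sched.get());
            copy_tensor_async_floats(res->t_sampled_probs,  sampled_probs_map,  seq_to_batch_idx, sched.get());
        }
        // token ids are always copied, to map filtered indices back to the vocab
        copy_tensor_async_token_ids(res->t_sampled_token_ids, sampled_token_ids_map, seq_to_batch_idx, sched.get());
    } else {
        // no backend samplers: the usual logits/embeddings extraction path
    }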
@@ -1133,6 +1133,54 @@ int llama_context::encode(const llama_batch & batch_inp) {
     return 0;
 }
 
+static std::unordered_map<llama_seq_id, int32_t> build_seq_to_batch_idx(const llama_ubatch & ubatch) {
+    std::unordered_map<llama_seq_id, int32_t> seq_to_batch_idx;
+    for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+        if (ubatch.output[i]) {
+            seq_to_batch_idx[ubatch.seq_id[i][0]] = i;
+        }
+    }
+    return seq_to_batch_idx;
+}
+
+static void copy_tensor_async_int(
+        const std::unordered_map<llama_seq_id, ggml_tensor*> & tensor_map,
+        std::unordered_map<int32_t, llama_token> & output_map,
+        const std::unordered_map<llama_seq_id, int32_t> & seq_to_batch_idx,
+        ggml_backend_sched_t sched) {
+    for (const auto & [seq_id, tensor] : tensor_map) {
+        const int32_t idx = seq_to_batch_idx.at(seq_id);
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+        ggml_backend_tensor_get_async(backend, tensor, &output_map[idx], 0, sizeof(output_map[idx]));
+    }
+}
+
+static void copy_tensor_async_floats(
+        const std::unordered_map<llama_seq_id, ggml_tensor*> & tensor_map,
+        std::unordered_map<int32_t, std::vector<float>> & output_map,
+        const std::unordered_map<llama_seq_id, int32_t> & seq_to_batch_idx,
+        ggml_backend_sched_t sched) {
+    for (const auto & [seq_id, tensor] : tensor_map) {
+        const int32_t idx = seq_to_batch_idx.at(seq_id);
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+        output_map[idx].resize(ggml_nelements(tensor));
+        ggml_backend_tensor_get_async(backend, tensor, output_map[idx].data(), 0, ggml_nbytes(tensor));
+    }
+}
+
+static void copy_tensor_async_token_ids(
+        const std::unordered_map<llama_seq_id, ggml_tensor*> & tensor_map,
+        std::unordered_map<int32_t, std::vector<llama_token>> & output_map,
+        const std::unordered_map<llama_seq_id, int32_t> & seq_to_batch_idx,
+        ggml_backend_sched_t sched) {
+    for (const auto & [seq_id, tensor] : tensor_map) {
+        const int32_t idx = seq_to_batch_idx.at(seq_id);
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+        output_map[idx].resize(ggml_nelements(tensor));
+        ggml_backend_tensor_get_async(backend, tensor, output_map[idx].data(), 0, ggml_nbytes(tensor));
+    }
+}
+
 int llama_context::decode(const llama_batch & batch_inp) {
     GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
 
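The three helpers above share one shape: look up the sequence's output index in the batch, ask the scheduler which backend owns the tensor, then enqueue an asynchronous device-to-host copy. They differ only in the host-side payload: copy_tensor_async_int writes a single llama_token in place, while the two vector variants first resize() the destination to ggml_nelements() and then copy ggml_nbytes(). A minimal caller sketch, assuming the internal types of llama-context.cpp are in scope (it mirrors, but is not literally, the call sites added further down):

    // hypothetical call site (sketch); identifiers as in this commit
    std::unordered_map<int32_t, llama_token>        sampled_tokens_map;
    std::unordered_map<int32_t, std::vector<float>> sampled_logits_map;

    // build the seq_id -> batch output index map once per ubatch
    const auto seq_to_batch_idx = build_seq_to_batch_idx(ubatch);

    // scalar payload: one sampled token per output sequence
    copy_tensor_async_int(res->t_sampled_tokens, sampled_tokens_map, seq_to_batch_idx, sched.get());

    // vector payload: the helper resizes the destination before the async copy
    copy_tensor_async_floats(res->t_sampled_logits, sampled_logits_map, seq_to_batch_idx, sched.get());

Two details worth noting: build_seq_to_batch_idx keys on seq_id[i][0] only and seq_to_batch_idx.at() throws std::out_of_range for a sequence with no output row, so callers rely on every sampled sequence having an output token in the ubatch; and copy_tensor_async_floats and copy_tensor_async_token_ids are otherwise identical, differing only in the destination element type, which a template could fold together at some cost to readability.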
@@ -1154,11 +1202,12 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // when computing embeddings, all tokens are output
     const bool output_all = cparams.embeddings;
+    const bool has_backend_samplers = !samplers.empty();
 
     if (!balloc->init(batch_inp, vocab, memory.get(), n_embd,
                 cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max,
                 output_all,
-                !samplers.empty())) {
+                has_backend_samplers)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
 
@@ -1312,56 +1361,27 @@ int llama_context::decode(const llama_batch & batch_inp) {
     //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
     //}
 
-    std::unordered_map<llama_seq_id, int32_t> seq_to_idx;
-    for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
-        if (ubatch.output[i]) {
-            llama_seq_id seq_id = ubatch.seq_id[i][0];
-            seq_to_idx[seq_id] = i;
-        }
-    }
-
-    // extract sampled tokens
-    for (const auto & [seq_id, t_token] : res->t_sampled_tokens) {
-        auto idx_it = seq_to_idx.find(seq_id);
-        GGML_ASSERT(idx_it != seq_to_idx.end());
-        const int32_t idx = idx_it->second;
-        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_token);
-        ggml_backend_tensor_get_async(backend, t_token, &sampled_tokens_map[idx], 0, sizeof(llama_token));
-    }
-
-    for (const auto & [seq_id, t_ids] : res->t_sampled_token_ids) {
-        auto idx_it = seq_to_idx.find(seq_id);
-        GGML_ASSERT(idx_it != seq_to_idx.end());
-        const int32_t idx = idx_it->second;
-        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_ids);
-        sampled_token_ids_map[idx].resize(ggml_nelements(t_ids));
-        ggml_backend_tensor_get_async(backend, t_ids, sampled_token_ids_map[idx].data(), 0, ggml_nbytes(t_ids));
-    }
-
-    if (res->t_sampled_tokens.empty()) {
-        for (const auto & [seq_id, t_logits] : res->t_sampled_logits) {
-            auto idx_it = seq_to_idx.find(seq_id);
-            GGML_ASSERT(idx_it != seq_to_idx.end());
-            const int32_t idx = idx_it->second;
-            ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
-            sampled_logits_map[idx].resize(ggml_nelements(t_logits));
-            ggml_backend_tensor_get_async(backend, t_logits, sampled_logits_map[idx].data(), 0, ggml_nbytes(t_logits));
-        }
-
-        // extract sampled probabilities
-        for (const auto & [seq_id, t_probs] : res->t_sampled_probs) {
-            auto idx_it = seq_to_idx.find(seq_id);
-            GGML_ASSERT(idx_it != seq_to_idx.end());
-            const int32_t idx = idx_it->second;
-            ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_probs);
-            sampled_probs_map[idx].resize(ggml_nelements(t_probs));
-            ggml_backend_tensor_get_async(backend, t_probs, sampled_probs_map[idx].data(), 0, ggml_nbytes(t_probs));
-        }
-    }
-
-    backend_has_sampled = !res->t_sampled_tokens.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
-
-    if (!backend_has_sampled) {
+    if (has_backend_samplers) {
+        const auto seq_to_batch_idx = build_seq_to_batch_idx(ubatch);
+
+        // If a backend sampler has sampled a token we only want to copy the
+        // sampled tokens and avoid copying logits and probabilities.
+        if (!res->t_sampled_tokens.empty()) {
+            // async copy the sampled tokens from the backend to the host.
+            copy_tensor_async_int(res->t_sampled_tokens, sampled_tokens_map, seq_to_batch_idx, sched.get());
+        } else {
+            // async copy the sampled logits/probs from the backend to the host.
+            copy_tensor_async_floats(res->t_sampled_logits, sampled_logits_map, seq_to_batch_idx, sched.get());
+            copy_tensor_async_floats(res->t_sampled_probs, sampled_probs_map, seq_to_batch_idx, sched.get());
+        }
+
+        // async copy the filtered token ids from the backend to the host.
+        // These are needed for:
+        // 1) Backend dist sampler to map indices to vocab token ids.
+        // 2) CPU samplers to associate filtered logits with their token ids.
+        copy_tensor_async_token_ids(res->t_sampled_token_ids, sampled_token_ids_map, seq_to_batch_idx, sched.get());
+    } else {
         auto * t_logits = res->get_logits();
         auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
 
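One thing to keep in mind when reading the new call sites: every transfer goes through ggml_backend_tensor_get_async, so the host-side maps are not safe to read until the scheduler has been synchronized. A sketch of the consumer side under that assumption (the actual synchronization point in decode lies outside this diff):

    // sketch: consuming the results of the async copies above; the
    // synchronize call must happen before the first host-side read
    ggml_backend_sched_synchronize(sched.get());

    for (const auto & [idx, token] : sampled_tokens_map) {
        // 'token' is now valid host memory for batch output index 'idx'
        GGML_UNUSED(idx); GGML_UNUSED(token);
    }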