sampling : simplify backend sampling logic in decode

This commit tries to simplify the backend sampling logic in
llama_context::decode.
Daniel Bevenius 2025-11-19 09:31:33 +01:00
parent 51fee29822
commit 7e98ebcc6b
1 changed file with 67 additions and 47 deletions
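For orientation before the diff: every per-sequence extraction loop that gets removed from decode follows the same shape, namely look up the ubatch output index for the sequence, ask the scheduler which backend holds the result tensor, and start an asynchronous device-to-host copy into a host-side map keyed by that index. A minimal sketch of that shared shape is below; the ggml calls are the ones used in the diff, while result_tensors and host_buffers are illustrative placeholder names rather than identifiers from the codebase.

    // Illustrative only: the per-sequence copy pattern the new helpers factor out.
    for (const auto & [seq_id, tensor] : result_tensors) {    // one result tensor per sequence
        const int32_t idx = seq_to_batch_idx.at(seq_id);      // output index of this sequence in the ubatch
        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
        host_buffers[idx].resize(ggml_nelements(tensor));     // size the host buffer to the tensor
        ggml_backend_tensor_get_async(backend, tensor, host_buffers[idx].data(), 0, ggml_nbytes(tensor));
    }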


@@ -1133,6 +1133,54 @@ int llama_context::encode(const llama_batch & batch_inp) {
     return 0;
 }
 
+static std::unordered_map<llama_seq_id, int32_t> build_seq_to_batch_idx(const llama_ubatch & ubatch) {
+    std::unordered_map<llama_seq_id, int32_t> seq_to_batch_idx;
+    for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+        if (ubatch.output[i]) {
+            seq_to_batch_idx[ubatch.seq_id[i][0]] = i;
+        }
+    }
+    return seq_to_batch_idx;
+}
+
+static void copy_tensor_async_int(
+        const std::unordered_map<llama_seq_id, ggml_tensor*> & tensor_map,
+        std::unordered_map<int32_t, llama_token> & output_map,
+        const std::unordered_map<llama_seq_id, int32_t> & seq_to_batch_idx,
+        ggml_backend_sched_t sched) {
+    for (const auto & [seq_id, tensor] : tensor_map) {
+        const int32_t idx = seq_to_batch_idx.at(seq_id);
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+        ggml_backend_tensor_get_async(backend, tensor, &output_map[idx], 0, sizeof(output_map[idx]));
+    }
+}
+
+static void copy_tensor_async_floats(
+        const std::unordered_map<llama_seq_id, ggml_tensor*> & tensor_map,
+        std::unordered_map<int32_t, std::vector<float>> & output_map,
+        const std::unordered_map<llama_seq_id, int32_t> & seq_to_batch_idx,
+        ggml_backend_sched_t sched) {
+    for (const auto & [seq_id, tensor] : tensor_map) {
+        const int32_t idx = seq_to_batch_idx.at(seq_id);
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+        output_map[idx].resize(ggml_nelements(tensor));
+        ggml_backend_tensor_get_async(backend, tensor, output_map[idx].data(), 0, ggml_nbytes(tensor));
+    }
+}
+
+static void copy_tensor_async_token_ids(
+        const std::unordered_map<llama_seq_id, ggml_tensor*> & tensor_map,
+        std::unordered_map<int32_t, std::vector<llama_token>> & output_map,
+        const std::unordered_map<llama_seq_id, int32_t> & seq_to_batch_idx,
+        ggml_backend_sched_t sched) {
+    for (const auto & [seq_id, tensor] : tensor_map) {
+        const int32_t idx = seq_to_batch_idx.at(seq_id);
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+        output_map[idx].resize(ggml_nelements(tensor));
+        ggml_backend_tensor_get_async(backend, tensor, output_map[idx].data(), 0, ggml_nbytes(tensor));
+    }
+}
+
 int llama_context::decode(const llama_batch & batch_inp) {
     GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
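To make the mapping concrete, here is a small standalone sketch of what build_seq_to_batch_idx computes. mock_ubatch is a stand-in for the real llama_ubatch (which carries more fields); the assumption, as in the helper above, is that only the first sequence id attached to an output token matters.

    #include <cstdint>
    #include <cstdio>
    #include <unordered_map>
    #include <vector>

    using llama_seq_id = int32_t;

    // mock stand-in for llama_ubatch: only the fields build_seq_to_batch_idx reads
    struct mock_ubatch {
        uint32_t n_tokens;
        std::vector<int8_t> output;                     // 1 if this token produces an output
        std::vector<std::vector<llama_seq_id>> seq_id;  // sequence ids attached to each token
    };

    int main() {
        // two sequences of three tokens each; only the last token of each produces output
        mock_ubatch ub = { 6, {0, 0, 1, 0, 0, 1}, {{0}, {0}, {0}, {1}, {1}, {1}} };

        std::unordered_map<llama_seq_id, int32_t> seq_to_batch_idx;
        for (uint32_t i = 0; i < ub.n_tokens; i++) {
            if (ub.output[i]) {
                seq_to_batch_idx[ub.seq_id[i][0]] = i;
            }
        }

        for (const auto & [seq, idx] : seq_to_batch_idx) {
            printf("seq %d -> ubatch index %d\n", seq, idx);  // expect: seq 0 -> 2, seq 1 -> 5
        }
        return 0;
    }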
@@ -1154,11 +1202,12 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // when computing embeddings, all tokens are output
     const bool output_all = cparams.embeddings;
 
+    const bool has_backend_samplers = !samplers.empty();
 
     if (!balloc->init(batch_inp, vocab, memory.get(), n_embd,
                       cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max,
                       output_all,
-                      !samplers.empty())) {
+                      has_backend_samplers)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
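A side note on the copy helpers added above: copy_tensor_async_floats and copy_tensor_async_token_ids differ only in the element type of the host-side vector, so the pair could in principle collapse into a single template. This is not what the commit does; the sketch below is only a hypothetical alternative built from the same calls.

    // Hypothetical alternative (not part of this commit): one templated helper
    // covering both vector-valued copies; only the element type T differs.
    template <typename T>
    static void copy_tensor_async_vec(
            const std::unordered_map<llama_seq_id, ggml_tensor*> & tensor_map,
            std::unordered_map<int32_t, std::vector<T>> & output_map,
            const std::unordered_map<llama_seq_id, int32_t> & seq_to_batch_idx,
            ggml_backend_sched_t sched) {
        for (const auto & [seq_id, tensor] : tensor_map) {
            const int32_t idx = seq_to_batch_idx.at(seq_id);
            ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
            output_map[idx].resize(ggml_nelements(tensor));
            ggml_backend_tensor_get_async(backend, tensor, output_map[idx].data(), 0, ggml_nbytes(tensor));
        }
    }

    // Usage would mirror the existing call sites, e.g.:
    //   copy_tensor_async_vec<float>(res->t_sampled_logits, sampled_logits_map, seq_to_batch_idx, sched.get());
    //   copy_tensor_async_vec<llama_token>(res->t_sampled_token_ids, sampled_token_ids_map, seq_to_batch_idx, sched.get());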
@@ -1312,56 +1361,27 @@ int llama_context::decode(const llama_batch & batch_inp) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
         //}
 
-        std::unordered_map<llama_seq_id, int32_t> seq_to_idx;
-        for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
-            if (ubatch.output[i]) {
-                llama_seq_id seq_id = ubatch.seq_id[i][0];
-                seq_to_idx[seq_id] = i;
-            }
-        }
-
-        // extract sampled tokens
-        for (const auto & [seq_id, t_token] : res->t_sampled_tokens) {
-            auto idx_it = seq_to_idx.find(seq_id);
-            GGML_ASSERT(idx_it != seq_to_idx.end());
-            const int32_t idx = idx_it->second;
-            ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_token);
-            ggml_backend_tensor_get_async(backend, t_token, &sampled_tokens_map[idx], 0, sizeof(llama_token));
-        }
-
-        for (const auto & [seq_id, t_ids] : res->t_sampled_token_ids) {
-            auto idx_it = seq_to_idx.find(seq_id);
-            GGML_ASSERT(idx_it != seq_to_idx.end());
-            const int32_t idx = idx_it->second;
-            ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_ids);
-            sampled_token_ids_map[idx].resize(ggml_nelements(t_ids));
-            ggml_backend_tensor_get_async(backend, t_ids, sampled_token_ids_map[idx].data(), 0, ggml_nbytes(t_ids));
-        }
-
-        if (res->t_sampled_tokens.empty()) {
-            for (const auto & [seq_id, t_logits] : res->t_sampled_logits) {
-                auto idx_it = seq_to_idx.find(seq_id);
-                GGML_ASSERT(idx_it != seq_to_idx.end());
-                const int32_t idx = idx_it->second;
-                ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
-                sampled_logits_map[idx].resize(ggml_nelements(t_logits));
-                ggml_backend_tensor_get_async(backend, t_logits, sampled_logits_map[idx].data(), 0, ggml_nbytes(t_logits));
-            }
-
-            // extract sampled probabilities
-            for (const auto & [seq_id, t_probs] : res->t_sampled_probs) {
-                auto idx_it = seq_to_idx.find(seq_id);
-                GGML_ASSERT(idx_it != seq_to_idx.end());
-                const int32_t idx = idx_it->second;
-                ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_probs);
-                sampled_probs_map[idx].resize(ggml_nelements(t_probs));
-                ggml_backend_tensor_get_async(backend, t_probs, sampled_probs_map[idx].data(), 0, ggml_nbytes(t_probs));
-            }
-        }
-
-        backend_has_sampled = !res->t_sampled_tokens.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
-
-        if (!backend_has_sampled) {
+        if (has_backend_samplers) {
+            const auto seq_to_batch_idx = build_seq_to_batch_idx(ubatch);
+
+            // If a backend sampler has sampled a token we only want to copy the
+            // sampled tokens and avoid copying logits and probabilities.
+            if (!res->t_sampled_tokens.empty()) {
+                // async copy the sampled tokens from the backend to the host.
+                copy_tensor_async_int(res->t_sampled_tokens, sampled_tokens_map, seq_to_batch_idx, sched.get());
+            } else {
+                // async copy the sampled logits/probs from the backend to the host.
+                copy_tensor_async_floats(res->t_sampled_logits, sampled_logits_map, seq_to_batch_idx, sched.get());
+                copy_tensor_async_floats(res->t_sampled_probs, sampled_probs_map, seq_to_batch_idx, sched.get());
+            }
+
+            // async copy the filtered token ids from the backend to the host.
+            // These are needed for:
+            // 1) Backend dist sampler to map indices to vocab token ids.
+            // 2) CPU samplers to associate filtered logits with their token ids.
+            copy_tensor_async_token_ids(res->t_sampled_token_ids, sampled_token_ids_map, seq_to_batch_idx, sched.get());
+        } else {
             auto * t_logits = res->get_logits();
             auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;