diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 0b7f3adf9b..8485416c3e 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1133,6 +1133,54 @@ int llama_context::encode(const llama_batch & batch_inp) {
     return 0;
 }
 
+static std::unordered_map<llama_seq_id, int32_t> build_seq_to_batch_idx(const llama_ubatch & ubatch) {
+    std::unordered_map<llama_seq_id, int32_t> seq_to_batch_idx;
+    for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+        if (ubatch.output[i]) {
+            seq_to_batch_idx[ubatch.seq_id[i][0]] = i;
+        }
+    }
+    return seq_to_batch_idx;
+}
+
+static void copy_tensor_async_int(
+        const std::unordered_map<llama_seq_id, ggml_tensor *> & tensor_map,
+        std::unordered_map<int32_t, llama_token> & output_map,
+        const std::unordered_map<llama_seq_id, int32_t> & seq_to_batch_idx,
+        ggml_backend_sched_t sched) {
+    for (const auto & [seq_id, tensor] : tensor_map) {
+        const int32_t idx = seq_to_batch_idx.at(seq_id);
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+        ggml_backend_tensor_get_async(backend, tensor, &output_map[idx], 0, sizeof(output_map[idx]));
+    }
+}
+
+static void copy_tensor_async_floats(
+        const std::unordered_map<llama_seq_id, ggml_tensor *> & tensor_map,
+        std::unordered_map<int32_t, std::vector<float>> & output_map,
+        const std::unordered_map<llama_seq_id, int32_t> & seq_to_batch_idx,
+        ggml_backend_sched_t sched) {
+    for (const auto & [seq_id, tensor] : tensor_map) {
+        const int32_t idx = seq_to_batch_idx.at(seq_id);
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+        output_map[idx].resize(ggml_nelements(tensor));
+        ggml_backend_tensor_get_async(backend, tensor, output_map[idx].data(), 0, ggml_nbytes(tensor));
+    }
+}
+
+static void copy_tensor_async_token_ids(
+        const std::unordered_map<llama_seq_id, ggml_tensor *> & tensor_map,
+        std::unordered_map<int32_t, std::vector<llama_token>> & output_map,
+        const std::unordered_map<llama_seq_id, int32_t> & seq_to_batch_idx,
+        ggml_backend_sched_t sched) {
+    for (const auto & [seq_id, tensor] : tensor_map) {
+        const int32_t idx = seq_to_batch_idx.at(seq_id);
+        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor);
+        output_map[idx].resize(ggml_nelements(tensor));
+        ggml_backend_tensor_get_async(backend, tensor, output_map[idx].data(), 0, ggml_nbytes(tensor));
+    }
+}
+
 int llama_context::decode(const llama_batch & batch_inp) {
     GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
 
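Note: the three helpers above share one read-back pattern: resolve the owning backend through the scheduler, then enqueue an asynchronous device-to-host copy. A minimal sketch of that pattern in isolation, using only the public ggml-backend API (the function name `read_back_f32` is illustrative and not part of this patch):

```cpp
#include <vector>

#include "ggml.h"
#include "ggml-backend.h"

// Sketch: copy a float tensor computed on some backend into host memory.
// The copy is asynchronous; the vector contents are only valid after the
// scheduler has been synchronized (see the note after the last hunk).
static std::vector<float> read_back_f32(ggml_backend_sched_t sched, ggml_tensor * t) {
    std::vector<float> host(ggml_nelements(t));

    // the scheduler knows which backend the tensor was assigned to
    ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, t);

    // enqueue a device -> host copy of the whole tensor
    ggml_backend_tensor_get_async(backend, t, host.data(), 0, ggml_nbytes(t));

    return host;
}
```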
@@ -1154,11 +1202,12 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // when computing embeddings, all tokens are output
     const bool output_all = cparams.embeddings;
+    const bool has_backend_samplers = !samplers.empty();
 
     if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ?
                 LLAMA_MAX_SEQ : cparams.n_seq_max, output_all,
-                !samplers.empty())) {
+                has_backend_samplers)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
 
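For intuition before the next hunk, here is a standalone toy of the sequence-to-output-row mapping that `build_seq_to_batch_idx` performs; `toy_ubatch` is a hypothetical stand-in for `llama_ubatch`, keeping only the two fields the helper reads:

```cpp
#include <cstdint>
#include <cstdio>
#include <unordered_map>
#include <vector>

using llama_seq_id = int32_t; // same underlying type as in llama.h

struct toy_ubatch {                              // stand-in for llama_ubatch
    uint32_t                               n_tokens;
    std::vector<std::vector<llama_seq_id>> seq_id; // seq_id[i][0] = primary sequence of token i
    std::vector<int8_t>                    output; // output[i] != 0 -> row i produces output
};

// mirrors build_seq_to_batch_idx: the last output row of each sequence wins
static std::unordered_map<llama_seq_id, int32_t> build_idx(const toy_ubatch & ub) {
    std::unordered_map<llama_seq_id, int32_t> m;
    for (uint32_t i = 0; i < ub.n_tokens; i++) {
        if (ub.output[i]) {
            m[ub.seq_id[i][0]] = i;
        }
    }
    return m;
}

int main() {
    // two sequences interleaved in one ubatch; only their final tokens are outputs
    const toy_ubatch ub = { 4, {{0}, {1}, {0}, {1}}, {0, 0, 1, 1} };
    for (const auto & [seq, idx] : build_idx(ub)) {
        printf("seq %d -> output row %d\n", seq, idx);
    }
    return 0;
}
```

This also shows why only `seq_id[i][0]` matters here: each sequence's results are keyed by its single output row within the batch.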
@@ -1312,56 +1361,27 @@ int llama_context::decode(const llama_batch & batch_inp) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
         //}
 
-        std::unordered_map<llama_seq_id, int32_t> seq_to_idx;
-        for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
-            if (ubatch.output[i]) {
-                llama_seq_id seq_id = ubatch.seq_id[i][0];
-                seq_to_idx[seq_id] = i;
-            }
-        }
+        if (has_backend_samplers) {
+            const auto seq_to_batch_idx = build_seq_to_batch_idx(ubatch);
 
-        // extract sampled tokens
-        for (const auto & [seq_id, t_token] : res->t_sampled_tokens) {
-            auto idx_it = seq_to_idx.find(seq_id);
-            GGML_ASSERT(idx_it != seq_to_idx.end());
-            const int32_t idx = idx_it->second;
-            ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_token);
-            ggml_backend_tensor_get_async(backend, t_token, &sampled_tokens_map[idx], 0, sizeof(llama_token));
-        }
-
-        for (const auto & [seq_id, t_ids] : res->t_sampled_token_ids) {
-            auto idx_it = seq_to_idx.find(seq_id);
-            GGML_ASSERT(idx_it != seq_to_idx.end());
-            const int32_t idx = idx_it->second;
-            ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_ids);
-            sampled_token_ids_map[idx].resize(ggml_nelements(t_ids));
-            ggml_backend_tensor_get_async(backend, t_ids, sampled_token_ids_map[idx].data(), 0, ggml_nbytes(t_ids));
-        }
-
-        if (res->t_sampled_tokens.empty()) {
-            for (const auto & [seq_id, t_logits] : res->t_sampled_logits) {
-                auto idx_it = seq_to_idx.find(seq_id);
-                GGML_ASSERT(idx_it != seq_to_idx.end());
-                const int32_t idx = idx_it->second;
-                ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
-                sampled_logits_map[idx].resize(ggml_nelements(t_logits));
-                ggml_backend_tensor_get_async(backend, t_logits, sampled_logits_map[idx].data(), 0, ggml_nbytes(t_logits));
+            // If a backend sampler has sampled a token we only want to copy the
+            // sampled tokens and avoid copying logits and probabilities.
+            if (!res->t_sampled_tokens.empty()) {
+                // async copy the sampled tokens from the backend to the host.
+                copy_tensor_async_int(res->t_sampled_tokens, sampled_tokens_map, seq_to_batch_idx, sched.get());
+            } else {
+                // async copy the sampled logits/probs from the backend to the host.
+                copy_tensor_async_floats(res->t_sampled_logits, sampled_logits_map, seq_to_batch_idx, sched.get());
+                copy_tensor_async_floats(res->t_sampled_probs, sampled_probs_map, seq_to_batch_idx, sched.get());
             }
 
-            // extract sampled probabilities
-            for (const auto & [seq_id, t_probs] : res->t_sampled_probs) {
-                auto idx_it = seq_to_idx.find(seq_id);
-                GGML_ASSERT(idx_it != seq_to_idx.end());
-                const int32_t idx = idx_it->second;
-                ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_probs);
-                sampled_probs_map[idx].resize(ggml_nelements(t_probs));
-                ggml_backend_tensor_get_async(backend, t_probs, sampled_probs_map[idx].data(), 0, ggml_nbytes(t_probs));
-            }
-        }
+            // async copy the filtered token ids from the backend to the host.
+            // These are needed for:
+            // 1) Backend dist sampler to map indices to vocab token ids.
+            // 2) CPU samplers to associate filtered logits with their token ids.
+            copy_tensor_async_token_ids(res->t_sampled_token_ids, sampled_token_ids_map, seq_to_batch_idx, sched.get());
 
-        backend_has_sampled = !res->t_sampled_tokens.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
-
-        if (!backend_has_sampled) {
+        } else {
             auto * t_logits = res->get_logits();
             auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
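One caveat worth calling out: every copy above is enqueued with `ggml_backend_tensor_get_async`, so the host-side maps are not safe to read until the scheduler has been synchronized. A hedged sketch of the consumer side (the helper name `first_sampled_token` and passing the map explicitly are illustrative; where decode() actually synchronizes is outside the hunks shown):

```cpp
#include <cstdint>
#include <unordered_map>

#include "ggml-backend.h"
#include "llama.h"

// Sketch: reading a backend-sampled token for a given output row of the batch.
static llama_token first_sampled_token(
        ggml_backend_sched_t sched,
        const std::unordered_map<int32_t, llama_token> & sampled_tokens_map,
        int32_t batch_idx) {
    // wait for all pending async device -> host copies issued during compute
    ggml_backend_sched_synchronize(sched);

    // only now does the map entry hold the sampled token
    return sampled_tokens_map.at(batch_idx);
}
```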