diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 8485416c3e..15ca80a735 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1361,7 +1361,9 @@ int llama_context::decode(const llama_batch & batch_inp) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - if (has_backend_samplers) { + const bool backend_has_sampled = !res->t_sampled_tokens.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty(); + + if (has_backend_samplers && backend_has_sampled) { const auto seq_to_batch_idx = build_seq_to_batch_idx(ubatch); // If a backend sampler has sampled a token we only want to copy the @@ -1381,7 +1383,9 @@ int llama_context::decode(const llama_batch & batch_inp) { // 2) CPU samplers to associate filtered logits with their token ids. copy_tensor_async_token_ids(res->t_sampled_token_ids, sampled_token_ids_map, seq_to_batch_idx, sched.get()); - } else { + } + + if (!backend_has_sampled) { auto * t_logits = res->get_logits(); auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;