diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index c03a23e701..9ccd8f3998 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1470,16 +1470,14 @@ int llama_context::decode(const llama_batch & batch_inp) {
         const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
         const auto stride = n_vocab;
 
-        // If a backend sampler has sampled a token we only want to copy the
-        // sampled tokens and avoid copying logits and probabilites.
-        if (!res->t_sampled.empty()) {
-            // async copy the sampled tokens from the backend to the host.
-            copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
-        } else {
-            // async copy the sampled logits/probs from the backend to the host.
-            copy_tensor_async_floats(res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
-            copy_tensor_async_floats(res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
-        }
+        // async copy the sampled tokens from the backend to the host.
+        copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
+
+        // async copy the sampled logits from the backend to the host.
+        copy_tensor_async_floats(res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
+
+        // async copy the sampled probabilities from the backend to the host.
+        copy_tensor_async_floats(res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
 
         // async copy the candidate token ids from the backend to the host.
         // These are needed by CPU samplers to map probability/logit indices to vocab token ids.
diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp
index 2ed13688c9..ebbb6e039e 100644
--- a/tests/test-backend-sampler.cpp
+++ b/tests/test-backend-sampler.cpp
@@ -573,6 +573,62 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
     GGML_ASSERT(backend_token == bias_token);
 }
 
+// This test verifies that it is possible to have two different backend sampler chains:
+// one that uses the backend dist sampler, and another that leaves dist sampling to the CPU.
+static void test_backend_mixed_sampling(const char * model_path) {
+    test_model_context test_ctx;
+
+    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
+    struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
+    llama_sampler_chain_add(sampler_chain_0, llama_sampler_backend_init_dist(88));
+
+    int k = 40;
+    struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
+    struct llama_sampler * sampler_chain_1 = llama_sampler_chain_init(chain_params_1);
+    llama_sampler_chain_add(sampler_chain_1, llama_sampler_backend_init_top_k(k));
+
+    std::vector backend_sampler_configs = {
+        { 0, sampler_chain_0 },
+        { 1, sampler_chain_1 }
+    };
+
+    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+        return;
+    }
+
+    std::map prompts = {
+        {0, "Hello"},
+        {1, "Some"}
+    };
+
+    if (!test_ctx.decode(prompts)) {
+        return;
+    }
+
+    // Verify sequence 0, which used the dist backend sampler.
+    {
+        int32_t batch_idx = test_ctx.idx_for_seq(0);
+        llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+        const std::string token_str = test_ctx.token_to_piece(token, false);
+        printf("sampled token id=%d, string='%s'\n", token, token_str.c_str());
+        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+        GGML_ASSERT(llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
+        GGML_ASSERT(llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx) == 0);
+    }
+
+    // Verify sequence 1, which used the top-k backend sampler.
+    {
+        int32_t batch_idx = test_ctx.idx_for_seq(1);
+        float * logits = llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx);
+        GGML_ASSERT(logits != nullptr);
+        size_t n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        GGML_ASSERT(n_logits == (size_t) k);
+        GGML_ASSERT(llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx) == LLAMA_TOKEN_NULL);
+    }
+
+    printf("backend mixed sampling test PASSED\n");
+}
+
 static void test_backend_set_sampler(const char * model_path) {
     test_model_context test_ctx;
 
@@ -695,6 +751,7 @@ static const backend_test_case BACKEND_TESTS[] = {
     { "dist_and_cpu", test_backend_dist_sampling_and_cpu, true },
    { "set_sampler", test_backend_set_sampler, true },
     { "max_outputs", test_backend_max_outputs, true },
+    { "mixed", test_backend_mixed_sampling, true },
 };
 
 struct backend_cli_args {
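
A minimal sketch, for reference, of how a caller might consume these mixed per-sequence results, assuming the getters behave exactly as the test above asserts. The helper name consume_backend_row is hypothetical, and the CPU-side completion is only outlined, since the candidate-id getter is not shown in this diff.

#include "llama.h"

// Hypothetical helper: consume the result for one batch row, regardless of whether the
// backend sampler chain for that sequence produced a final token or only filtered logits.
static llama_token consume_backend_row(llama_context * ctx, int32_t batch_idx) {
    // A chain ending in the backend dist sampler yields the token directly.
    const llama_token tok = llama_get_backend_sampled_token_ith(ctx, batch_idx);
    if (tok != LLAMA_TOKEN_NULL) {
        return tok;
    }

    // A chain that only filters (e.g. top-k) exposes the reduced logits on the host instead.
    float      * logits   = llama_get_backend_sampled_logits_ith(ctx, batch_idx);
    const size_t n_logits = llama_get_backend_sampled_logits_count_ith(ctx, batch_idx);
    if (logits == nullptr || n_logits == 0) {
        return LLAMA_TOKEN_NULL;
    }

    // A CPU sampler would now map these n_logits entries back to vocab token ids using the
    // candidate ids copied in llama_context::decode(), then finish sampling; omitted here.
    return LLAMA_TOKEN_NULL;
}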