sampling : fix copying both sampled tokens and logits/probs from backend

This commit fixes an issue where either the sampled tokens or the
logits/probs, but not both, were copied from the backend to the host
when multiple backend samplers were used.

A test for this scenario has also been added to ensure that both types
of data are copied correctly when different backend samplers are
employed.
Daniel Bevenius 2025-11-23 13:08:08 +01:00
parent ae23d2d2c1
commit 9e273f7aa4
2 changed files with 65 additions and 10 deletions
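
For context, here is a minimal sketch (not part of the commit) of how a
caller could inspect the per-sequence results after a decode once this fix
is in place. It assumes only the getter functions exercised by the new test
below; the helper name read_backend_result and its arguments are
illustrative.

#include <cstdio>
#include "llama.h"

// Illustrative sketch: a sequence whose backend sampler chain ends in a dist
// sampler yields a single token id, while a chain that ends in top-k exposes
// the filtered logits so a CPU sampler can pick the final token.
static void read_backend_result(llama_context * ctx, int32_t batch_idx) {
    llama_token token = llama_get_backend_sampled_token_ith(ctx, batch_idx);
    if (token != LLAMA_TOKEN_NULL) {
        // fully sampled on the backend (e.g. llama_sampler_backend_init_dist)
        printf("output %d: backend sampled token id=%d\n", batch_idx, token);
        return;
    }
    // only filtered logits are available (e.g. llama_sampler_backend_init_top_k);
    // a CPU dist sampler would choose the final token from these.
    float * logits   = llama_get_backend_sampled_logits_ith(ctx, batch_idx);
    size_t  n_logits = llama_get_backend_sampled_logits_count_ith(ctx, batch_idx);
    printf("output %d: %zu backend logits at %p\n", batch_idx, n_logits, (void *) logits);
}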


@@ -1470,16 +1470,14 @@ int llama_context::decode(const llama_batch & batch_inp) {
     const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
     const auto stride = n_vocab;
 
-    // If a backend sampler has sampled a token we only want to copy the
-    // sampled tokens and avoid copying logits and probabilites.
-    if (!res->t_sampled.empty()) {
-        // async copy the sampled tokens from the backend to the host.
-        copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
-    } else {
-        // async copy the sampled logits/probs from the backend to the host.
-        copy_tensor_async_floats(res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
-        copy_tensor_async_floats(res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
-    }
+    // async copy the sampled tokens from the backend to the host.
+    copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get());
+
+    // async copy the sampled logits from the backend to the host.
+    copy_tensor_async_floats(res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get());
+
+    // async copy the sampled probabilities from the backend to the host.
+    copy_tensor_async_floats(res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get());
 
     // async copy the candidate token ids from the backend to the host.
     // These are needed by CPU samplers to map probability/logit indices to vocab token ids.


@@ -573,6 +573,62 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
     GGML_ASSERT(backend_token == bias_token);
 }
 
+// This test verifies that it is possible to have two different backend samplers,
+// one that uses the backend dist sampler, and another that uses the backend
+// top-k sampler and leaves the final sampling to a CPU dist sampler.
+static void test_backend_mixed_sampling(const char * model_path) {
+    test_model_context test_ctx;
+
+    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
+    struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
+    llama_sampler_chain_add(sampler_chain_0, llama_sampler_backend_init_dist(88));
+
+    int k = 40;
+    struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
+    struct llama_sampler * sampler_chain_1 = llama_sampler_chain_init(chain_params_1);
+    llama_sampler_chain_add(sampler_chain_1, llama_sampler_backend_init_top_k(k));
+
+    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
+        { 0, sampler_chain_0 },
+        { 1, sampler_chain_1 }
+    };
+
+    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+        return;
+    }
+
+    std::map<llama_seq_id, std::string> prompts = {
+        {0, "Hello"},
+        {1, "Some"}
+    };
+
+    if (!test_ctx.decode(prompts)) {
+        return;
+    }
+
+    // Verify sequence 0, which used the backend dist sampler.
+    {
+        int32_t batch_idx = test_ctx.idx_for_seq(0);
+
+        llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+        const std::string token_str = test_ctx.token_to_piece(token, false);
+        printf("sampled token id=%d, string='%s'\n", token, token_str.c_str());
+        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+
+        GGML_ASSERT(llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
+        GGML_ASSERT(llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx) == 0);
+    }
+
+    // Verify sequence 1, which used the backend top-k sampler.
+    {
+        int32_t batch_idx = test_ctx.idx_for_seq(1);
+
+        float * logits = llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx);
+        GGML_ASSERT(logits != nullptr);
+
+        size_t n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        GGML_ASSERT(n_logits == (size_t) k);
+
+        GGML_ASSERT(llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx) == LLAMA_TOKEN_NULL);
+    }
+
+    printf("backend mixed sampling test PASSED\n");
+}
+
 static void test_backend_set_sampler(const char * model_path) {
     test_model_context test_ctx;

@@ -695,6 +751,7 @@ static const backend_test_case BACKEND_TESTS[] = {
     { "dist_and_cpu", test_backend_dist_sampling_and_cpu, true },
     { "set_sampler", test_backend_set_sampler, true },
     { "max_outputs", test_backend_max_outputs, true },
+    { "mixed", test_backend_mixed_sampling, true },
 };
 
 struct backend_cli_args {