From db8972e2517be09f13c0377605174850db72780b Mon Sep 17 00:00:00 2001
From: Daniel Bevenius
Date: Tue, 2 Dec 2025 11:53:29 +0100
Subject: [PATCH] squash! sampling : fix backend temp sampler for zero
 temperature

This modifies the parent commit to simply return the most probable token
instead of masking the logits.
---
 src/llama-sampling.cpp         | 38 ++--------------------------------
 tests/test-backend-sampler.cpp | 25 ++++------------------
 2 files changed, 6 insertions(+), 57 deletions(-)

diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 589bb8e1a2..d40f324b84 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1607,42 +1607,8 @@ static void llama_sampler_temp_backend_apply(
         struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
         ggml_set_name(max_idx, "temp_max_idx");
 
-        // Reshape logits to 2D so we can use get_rows.
-        struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
-        ggml_set_name(logits_rows, "temp_logits_rows");
-
-        // Get the max logit value.
-        struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
-        ggml_set_name(max_logit, "temp_max_logit");
-
-        // Repeat max_logit to match logits shape for element-wise operations.
-        struct ggml_tensor * max_logit_repeated = ggml_repeat(ctx, max_logit, data->logits);
-        ggml_set_name(max_logit_repeated, "temp_max_logit_repeated");
-
-        // Compute diff = max - logits.
-        // At max_idx position this value will be zero, and positive elsewhere.
-        struct ggml_tensor * diff = ggml_sub(ctx, max_logit_repeated, data->logits);
-        ggml_set_name(diff, "temp_diff");
-
-        // Subtract small epsilon to make max position negative.
-        // This ensures ggml_step returns 0 at max across all backends.
-        struct ggml_tensor * diff_eps = ggml_scale_bias(ctx, diff, 1.0f, -1e-6f);
-        ggml_set_name(diff_eps, "temp_diff_eps");
-
-        // Create mask: max position gets 0, everything else gets 1.
-        struct ggml_tensor * mask = ggml_step(ctx, diff_eps);
-        ggml_set_name(mask, "temp_mask");
-
-        // Convert mask to bias: -1e9 for non-max, 0 for max
-        const float large_val = 1e9f;
-        struct ggml_tensor * bias = ggml_scale_bias(ctx, mask, -large_val, 0.0f);
-        ggml_set_name(bias, "temp_bias");
-
-        // Add the bias to logits to mask out non-max tokens.
-        data->logits = ggml_add(ctx, data->logits, bias);
-        ggml_set_name(data->logits, "temp_zero_logits");
-
-        ggml_build_forward_expand(gf, data->logits);
+        // Set the sampled token to the most probable token.
+        data->sampled = max_idx;
 
         return;
     }
diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp
index d32fdf2fb7..cea892236d 100644
--- a/tests/test-backend-sampler.cpp
+++ b/tests/test-backend-sampler.cpp
@@ -455,31 +455,14 @@ static void test_backend_temp_sampling(const char * model_path) {
         }
 
         int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-        int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
-        GGML_ASSERT(n_logits == test_ctx.n_vocab);
 
-        // Sample from sequence using CPU sampler
-        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-        struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-        llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
-
-        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
-        const std::string token_str = test_ctx.token_to_piece(token, false);
-        printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 
-        // Verify that only one logit is available and the rest are masked out
        float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
-        std::vector<float> unmasked;
-        for (int i = 0; i < n_logits; ++i) {
-            if (logits[i] > -1e9f) {
-                unmasked.push_back(logits[i]);
-            }
-        }
-        GGML_ASSERT(unmasked.size() == 1);
-        printf("Temperature %.1f test: unmasked size: %d\n", temp, (int)unmasked.size());
-
-        llama_sampler_free(chain);
+        GGML_ASSERT(logits == nullptr);
+        int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        GGML_ASSERT(n_logits == 0);
     };
 
     test_argmax_temp(0.0f);
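
Note for reviewers: the deleted graph built the zero-temperature result by
masking. It computed diff = max - logits (zero at the argmax, positive
everywhere else), shifted it by -1e-6 so that ggml_step would return 0 only
at the max position, and then used the resulting mask to add a -1e9 bias to
every non-max logit. The following is a scalar sketch of that construction,
in plain C++ rather than ggml ops, purely for illustration:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<float> logits = { -1.2f, 3.4f, 0.7f, 3.1f };

        // Stands in for the ggml_get_rows/ggml_repeat result in the old graph.
        const float max_logit = *std::max_element(logits.begin(), logits.end());

        for (float & x : logits) {
            // diff is 0 at the max and positive elsewhere; the epsilon makes
            // the max strictly negative so the step function yields 0 there.
            const float diff_eps = (max_logit - x) - 1e-6f;
            const float mask     = diff_eps > 0.0f ? 1.0f : 0.0f; // ggml_step analogue
            x += mask * -1e9f;                                    // bias out non-max logits
        }

        // Only logits[1] (the max) keeps its original value.
        for (size_t i = 0; i < logits.size(); ++i) {
            std::printf("logit[%zu] = %g\n", i, logits[i]);
        }
        return 0;
    }

The -1e-6 shift existed only to make ggml_step behave consistently across
backends at the exact max position, which appears to be the fragility this
patch removes along with the rest of the mask construction.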
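
The replacement path is much simpler: at zero temperature, sampling
degenerates to a greedy argmax over the raw logits, so the sampler can store
the ggml_argmax result in data->sampled and skip any logits post-processing.
A minimal self-contained sketch of that selection rule (argmax_token is a
hypothetical helper for illustration, not part of the llama.cpp API):

    #include <cstdio>
    #include <vector>

    // Hypothetical helper: index of the largest logit, i.e. the token a
    // zero-temperature (greedy) sampler selects.
    static int argmax_token(const std::vector<float> & logits) {
        int best = 0;
        for (int i = 1; i < (int) logits.size(); ++i) {
            if (logits[i] > logits[best]) {
                best = i;
            }
        }
        return best;
    }

    int main() {
        const std::vector<float> logits = { -1.2f, 3.4f, 0.7f, 3.1f };
        // Token 1 has the largest logit, so it is the zero-temperature sample.
        std::printf("sampled token: %d\n", argmax_token(logits));
        return 0;
    }

This is also why the test changes shape: the token is now read back through
llama_get_sampled_token_ith(), and because no masked logits are produced,
llama_get_sampled_logits_ith() returns nullptr and the sampled-logits count
is 0.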