From 739b59780432d3c8621d07dfa206d19f48fd51d5 Mon Sep 17 00:00:00 2001
From: Daniel Bevenius
Date: Tue, 2 Dec 2025 09:03:08 +0100
Subject: [PATCH] sampling : fix backend temp sampler for zero temperature

This commit fixes the implementation of the temperature-based sampler
for the case when the temperature is zero or negative. The sampler now
correctly selects the most probable token by masking out all other
tokens in the logits with a large negative bias.
---
 src/llama-sampling.cpp         | 41 ++++++++-
 tests/test-backend-sampler.cpp | 152 +++++++++++++++++++++++----------
 2 files changed, 149 insertions(+), 44 deletions(-)

diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index db7f2770b5..589bb8e1a2 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1603,7 +1603,46 @@ static void llama_sampler_temp_backend_apply(
     auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
 
     if (ctx_data->temp <= 0.0f) {
-        // TODO: this is incorrect - find the most probable token instead
+        // Find the most probable token index.
+        struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
+        ggml_set_name(max_idx, "temp_max_idx");
+
+        // Reshape logits to 2D so we can use get_rows.
+        struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
+        ggml_set_name(logits_rows, "temp_logits_rows");
+
+        // Get the max logit value.
+        struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
+        ggml_set_name(max_logit, "temp_max_logit");
+
+        // Repeat max_logit to match the logits shape for element-wise operations.
+        struct ggml_tensor * max_logit_repeated = ggml_repeat(ctx, max_logit, data->logits);
+        ggml_set_name(max_logit_repeated, "temp_max_logit_repeated");
+
+        // Compute diff = max - logits.
+        // At the max_idx position this value is zero, and positive elsewhere.
+        struct ggml_tensor * diff = ggml_sub(ctx, max_logit_repeated, data->logits);
+        ggml_set_name(diff, "temp_diff");
+
+        // Subtract a small epsilon to make the max position negative.
+        // This ensures ggml_step returns 0 at the max across all backends.
+        struct ggml_tensor * diff_eps = ggml_scale_bias(ctx, diff, 1.0f, -1e-6f);
+        ggml_set_name(diff_eps, "temp_diff_eps");
+
+        // Create mask: the max position gets 0, everything else gets 1.
+        struct ggml_tensor * mask = ggml_step(ctx, diff_eps);
+        ggml_set_name(mask, "temp_mask");
+
+        // Convert mask to bias: -1e9 for non-max positions, 0 for the max.
+        const float large_val = 1e9f;
+        struct ggml_tensor * bias = ggml_scale_bias(ctx, mask, -large_val, 0.0f);
+        ggml_set_name(bias, "temp_bias");
+
+        // Add the bias to the logits to mask out non-max tokens.
+        data->logits = ggml_add(ctx, data->logits, bias);
+        ggml_set_name(data->logits, "temp_zero_logits");
+
+        ggml_build_forward_expand(gf, data->logits);
         return;
     }
 
diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp
index f185cebe9d..d32fdf2fb7 100644
--- a/tests/test-backend-sampler.cpp
+++ b/tests/test-backend-sampler.cpp
@@ -250,6 +250,15 @@ struct test_model_context {
         return piece;
     }
 
+    void reset() {
+        if (ctx) {
+            llama_free(ctx);
+            ctx = nullptr;
+        }
+        seq_positions.clear();
+        last_batch_info.clear();
+    }
+
     void cleanup() {
         if (ctx) {
             llama_free(ctx);
@@ -360,36 +369,96 @@ static void test_backend_top_k_sampling(const char * model_path) {
 static void test_backend_temp_sampling(const char * model_path) {
     test_model_context test_ctx;
 
-    const float temp_0 = 0.8f;
-    struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
-    struct llama_sampler * backend_sampler_chain_0 = llama_sampler_chain_init(backend_chain_params_0);
-    llama_sampler_chain_add(backend_sampler_chain_0, llama_sampler_init_temp(temp_0));
-
-    const float temp_1 = 0.1f;
-    struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params();
-    struct llama_sampler * backend_sampler_chain_1 = llama_sampler_chain_init(backend_chain_params_1);
-    llama_sampler_chain_add(backend_sampler_chain_1, llama_sampler_init_temp(temp_1));
-
-    std::vector backend_sampler_configs = {
-        { 0, backend_sampler_chain_0 },
-        { 1, backend_sampler_chain_1 }
-    };
-
-    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
-        return;
-    }
-
-    if (!test_ctx.decode({{0, "Some where over"}, {1, "Once upon a"}})) {
-        GGML_ASSERT(false && "Failed to decode token");
-    }
-
-    // Verfify sequence 0
     {
-        int32_t batch_idx = test_ctx.idx_for_seq(0);
+        const float temp_0 = 0.8f;
+        struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
+        struct llama_sampler * backend_sampler_chain_0 = llama_sampler_chain_init(backend_chain_params_0);
+        llama_sampler_chain_add(backend_sampler_chain_0, llama_sampler_init_temp(temp_0));
+
+        const float temp_1 = 0.1f;
+        struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params();
+        struct llama_sampler * backend_sampler_chain_1 = llama_sampler_chain_init(backend_chain_params_1);
+        llama_sampler_chain_add(backend_sampler_chain_1, llama_sampler_init_temp(temp_1));
+
+        std::vector backend_sampler_configs = {
+            { 0, backend_sampler_chain_0 },
+            { 1, backend_sampler_chain_1 }
+        };
+
+        if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+            return;
+        }
+
+        if (!test_ctx.decode({{0, "Some where over"}, {1, "Once upon a"}})) {
+            GGML_ASSERT(false && "Failed to decode token");
+        }
+
+        // Verify sequence 0
+        {
+            int32_t batch_idx = test_ctx.idx_for_seq(0);
+            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+            GGML_ASSERT(n_logits == test_ctx.n_vocab);
+
+            // Sample from sequence 0 using CPU sampler
+            struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
+            struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
+            llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
+
+            llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+            const std::string token_str = test_ctx.token_to_piece(token, false);
+            printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
+            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+
+            llama_sampler_free(chain);
+        }
+
+        // Verify sequence 1
+        {
+            int32_t batch_idx = test_ctx.idx_for_seq(1);
+
+            // Sample from sequence 1 using CPU sampler
+            struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
+            struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
+            llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
+
+            llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+            const std::string token_str = test_ctx.token_to_piece(token, false);
+            printf("Sequence 1 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
+            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+
+            llama_sampler_free(chain);
+        }
+    }
+
+    // Lambda for testing non-positive temperature values.
+    auto test_argmax_temp = [&](float temp) {
+        printf("\nTesting temperature = %.1f\n", temp);
+
+        test_ctx.reset();
+
+        int seq_id = 0;
+        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
+        struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
+        llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp(temp));
+
+        std::vector backend_sampler_configs = {
+            { seq_id, backend_sampler_chain },
+        };
+
+        if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+            return;
+        }
+
+        if (!test_ctx.decode({{seq_id, "Once"}})) {
+            GGML_ASSERT(false && "Failed to decode token");
+        }
+
+        int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
         int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
         GGML_ASSERT(n_logits == test_ctx.n_vocab);
 
-        // Sample from sequence 0 using CPU sampler
+        // Sample from sequence using CPU sampler
         struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
         struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
         llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
@@ -399,25 +468,22 @@ static void test_backend_temp_sampling(const char * model_path) {
         printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 
-        llama_sampler_free(chain);
-    }
-
-    // Verfify sequence 1
-    {
-        int32_t batch_idx = test_ctx.idx_for_seq(1);
-
-        // Sample from sequence 1 using CPU sampler
-        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-        struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-        llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
-
-        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
-        const std::string token_str = test_ctx.token_to_piece(token, false);
-        printf("Sequence 1 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
-        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+        // Verify that only one logit is available and the rest are masked out
+        float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
+        std::vector<float> unmasked;
+        for (int i = 0; i < n_logits; ++i) {
+            if (logits[i] > -1e9f) {
+                unmasked.push_back(logits[i]);
+            }
+        }
+        GGML_ASSERT(unmasked.size() == 1);
+        printf("Temperature %.1f test: unmasked size: %d\n", temp, (int)unmasked.size());
 
         llama_sampler_free(chain);
-    }
+    };
+
+    test_argmax_temp(0.0f);
+    test_argmax_temp(-1.0f);
 
     printf("backend temp sampling test PASSED\n");
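
Note (not part of the patch): the zero-temperature path builds the argmax mask out of ggml ops so that it runs on the backend. The following is a minimal, standalone CPU-side sketch of the same masking scheme in plain C++, for readers who want to see the intended result without the graph machinery. The helper name mask_all_but_argmax and the example logit values are made up for this sketch and do not appear in the patch.

#include <algorithm>
#include <cstdio>
#include <iterator>
#include <vector>

// Mask every logit except the most probable one: non-max positions receive a
// large negative bias so that any subsequent sampler can only pick the argmax
// token, which is what the backend temp sampler above achieves with tensors.
static void mask_all_but_argmax(std::vector<float> & logits) {
    const size_t max_idx = std::distance(
        logits.begin(), std::max_element(logits.begin(), logits.end()));

    const float large_val = 1e9f; // same magnitude as the bias in the patch
    for (size_t i = 0; i < logits.size(); ++i) {
        if (i != max_idx) {
            logits[i] -= large_val;
        }
    }
}

int main() {
    std::vector<float> logits = { 0.1f, 2.5f, -1.0f, 0.7f }; // arbitrary example values
    mask_all_but_argmax(logits);
    for (size_t i = 0; i < logits.size(); ++i) {
        printf("logit[%zu] = %f\n", i, logits[i]); // only index 1 keeps its original value
    }
    return 0;
}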