diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 3ee8d3210f..3a3931a23c 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1587,7 +1587,7 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
     delete (llama_sampler_temp *) smpl->ctx;
 }
 
-static void temp_sampling(
+static void llama_sampler_backend_temp_sampling(
         struct ggml_context * ctx,
         struct ggml_cgraph * gf,
         struct llama_sampler_data * data,
@@ -1597,8 +1597,32 @@ static void temp_sampling(
         struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
         ggml_set_name(max_idx, "temp_max_idx");
 
-        // Set the sampled token to the most probable token.
-        data->sampled = max_idx;
+        // Reshape to 2D so we can use get_rows.
+        struct ggml_tensor * logits_2d = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
+        ggml_set_name(logits_2d, "temp_logits_2d");
+        struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_2d, max_idx);
+        ggml_set_name(max_logit, "temp_max_logit");
+
+        // Subtract the max logit from all logits.
+        struct ggml_tensor * diff = ggml_sub(ctx, data->logits, max_logit);
+        ggml_set_name(diff, "temp_diff");
+
+        // Add a small epsilon so the max position becomes strictly positive.
+        struct ggml_tensor * diff_eps = ggml_scale_bias(ctx, diff, 1.0f, 1e-6f);
+        ggml_set_name(diff_eps, "temp_diff_eps");
+
+        // Create the mask for the max logit.
+        struct ggml_tensor * mask = ggml_step(ctx, diff_eps);
+        ggml_set_name(mask, "temp_mask");
+
+        // Create the bias.
+        const float large_val = 1e9f;
+        struct ggml_tensor * bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
+        ggml_set_name(bias, "temp_bias");
+
+        // Add the bias to the logits.
+        data->logits = ggml_add(ctx, data->logits, bias);
+        ggml_build_forward_expand(gf, data->logits);
 
         return;
     }
@@ -1618,7 +1642,7 @@ static void llama_sampler_temp_backend_apply(
         struct ggml_cgraph * gf,
         struct llama_sampler_data * data) {
     auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
-    temp_sampling(ctx, gf, data, ctx_data->temp);
+    llama_sampler_backend_temp_sampling(ctx, gf, data, ctx_data->temp);
 }
 
 static struct llama_sampler_i llama_sampler_temp_i = {
@@ -1750,7 +1774,7 @@ static void llama_sampler_temp_ext_backend_apply(
 
     // Revert to standard temperature scaling if delta or temp are non-positive.
     if (ctx_data->delta <= 0.0f || ctx_data->temp <= 0.0f) {
-        temp_sampling(ctx, gf, data, ctx_data->temp);
+        llama_sampler_backend_temp_sampling(ctx, gf, data, ctx_data->temp);
         return;
     }
 
diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp
index f56cce6350..6b11df3bcb 100644
--- a/tests/test-backend-sampler.cpp
+++ b/tests/test-backend-sampler.cpp
@@ -456,13 +456,17 @@ static void test_backend_temp_sampling(const char * model_path) {
         int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 
-        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
-        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
-
         float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
-        GGML_ASSERT(logits == nullptr);
-
-        int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
-        GGML_ASSERT(n_logits == 0);
+        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
+
+        std::vector<float> masked_logits;
+        for (size_t i = 0; i < n_logits; ++i) {
+            if (logits[i] <= -1e9f) {
+                masked_logits.push_back(logits[i]);
+            }
+        }
+
+        GGML_ASSERT(masked_logits.size() == (size_t) test_ctx.n_vocab - 1);
     };
 
     test_argmax_temp(0.0f);
@@ -531,22 +535,27 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
         int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
 
-        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+        float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
+        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
 
-        if (temp <= 0.0f) {
-            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
-            GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
-            GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx) == 0);
+        std::vector<float> masked_logits;
+        for (size_t i = 0; i < n_logits; ++i) {
+            if (logits[i] <= -1e9f) {
+                masked_logits.push_back(logits[i]);
+            }
+        }
+
+        if (temp <= 0.0f && delta >= 0.0f) {
+            GGML_ASSERT(masked_logits.size() == (size_t) test_ctx.n_vocab - 1);
         } else {
-            GGML_ASSERT(token == LLAMA_TOKEN_NULL);
-
-            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
-            GGML_ASSERT(n_logits == test_ctx.n_vocab);
+            printf("masked logits size: %zu\n", masked_logits.size());
+            GGML_ASSERT(masked_logits.size() == 0);
         }
     };
 
     test_argmax_temp(0.0f, 0.3f, 1.0f);    // Greedy (temp=0)
     test_argmax_temp(-1.0f, 0.3f, 2.0f);   // Greedy (temp<0)
-    test_argmax_temp(0.8f, 0.0f, 2.0f);    // Temperature scaling (should have scaled logits)
+    test_argmax_temp(0.8f, 0.0f, 2.0f);    // Temperature scaling
 
     printf("backend temp_ext sampling test PASSED\n");
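Note (reviewer sketch, not part of the patch): instead of materializing a token id on the graph, the patch keeps the full logit vector and pushes every non-argmax position down by 1e9 via step(logits - max_logit + eps) * 1e9 - 1e9. Below is a minimal scalar C++ sketch of that trick under the same constants, with plain loops standing in for the ggml ops (ggml_argmax, ggml_get_rows, ggml_sub, ggml_scale_bias, ggml_step, ggml_add); names here are illustrative only.

    // Scalar sketch of the argmax-masking trick; the patch builds the
    // equivalent ops into a ggml graph instead.
    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Shift every logit except the argmax by -1e9, mirroring
    // step(logits - max_logit + eps) * 1e9 - 1e9 added onto the logits.
    static void mask_non_argmax(std::vector<float> & logits) {
        const float eps       = 1e-6f;
        const float large_val = 1e9f;

        // ggml_argmax: index of the largest logit.
        size_t max_idx = 0;
        for (size_t i = 1; i < logits.size(); ++i) {
            if (logits[i] > logits[max_idx]) {
                max_idx = i;
            }
        }
        const float max_logit = logits[max_idx]; // ggml_get_rows on the 2D view

        for (size_t i = 0; i < logits.size(); ++i) {
            const float diff = logits[i] - max_logit + eps; // ggml_sub + ggml_scale_bias
            const float mask = diff > 0.0f ? 1.0f : 0.0f;   // ggml_step: 1 only at the max
            logits[i] += mask * large_val - large_val;      // ggml_scale_bias + ggml_add
        }
    }

    int main() {
        std::vector<float> logits = { 0.1f, 2.5f, -1.0f, 0.7f };
        mask_non_argmax(logits);

        // Only index 1 (the argmax) keeps its value; the rest drop by 1e9,
        // which is what the updated tests count via the logits[i] <= -1e9f check.
        assert(logits[1] == 2.5f);
        for (size_t i = 0; i < logits.size(); ++i) {
            assert(i == 1 || logits[i] <= -1e9f);
        }
        return 0;
    }

The eps term matters: at the max position, logits[i] - max_logit is exactly 0, and step(0) is 0, so without the bias to a strictly positive value the max itself would also be masked.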