diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 3a3931a23c..bca38ffa15 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1597,32 +1597,13 @@ static void llama_sampler_backend_temp_sampling( struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits); ggml_set_name(max_idx, "temp_max_idx"); - // Reshape to 2D and so we can use get_rows. - struct ggml_tensor * logits_2d = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); - ggml_set_name(logits_2d, "temp_logits_2d"); - struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_2d, max_idx); - ggml_set_name(max_logit, "temp_max_logit"); + data->candidates = max_idx; - // Subtract the max_logit from all logits. - struct ggml_tensor * diff = ggml_sub(ctx, data->logits, max_logit); - ggml_set_name(diff, "temp_diff"); + struct ggml_tensor * logit = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); - // Add small epsilon to make max position strictly positive. - struct ggml_tensor * diff_eps = ggml_scale_bias(ctx, diff, 1.0f, 1e-6f); - ggml_set_name(diff_eps, "temp_diff_eps"); - - // Create the mask for the max logit. - struct ggml_tensor * mask = ggml_step(ctx, diff_eps); - ggml_set_name(mask, "temp_mask"); - - // Create the bias. - const float large_val = 1e9f; - struct ggml_tensor * bias = ggml_scale_bias(ctx, mask, large_val, -large_val); - ggml_set_name(bias, "temp_bias"); - - // Add the bias to the logits. - data->logits = ggml_add(ctx, data->logits, bias); + data->logits = ggml_get_rows(ctx, logit, max_idx); ggml_build_forward_expand(gf, data->logits); + return; } diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp index 6b11df3bcb..5ef5fa396c 100644 --- a/tests/test-backend-sampler.cpp +++ b/tests/test-backend-sampler.cpp @@ -456,17 +456,8 @@ static void test_backend_temp_sampling(const char * model_path) { int32_t batch_idx = test_ctx.idx_for_seq(seq_id); - float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx); uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx); - GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab); - - std::vector masked_logits; - for (size_t i = 0; i < n_logits; ++i) { - if (logits[i] <= -1e9f) { - masked_logits.push_back(logits[i]); - } - } - GGML_ASSERT(masked_logits.size() == (size_t) test_ctx.n_vocab - 1); + GGML_ASSERT(n_logits == 1); }; test_argmax_temp(0.0f); @@ -535,21 +526,12 @@ static void test_backend_temp_ext_sampling(const char * model_path) { int32_t batch_idx = test_ctx.idx_for_seq(seq_id); - float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx); uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx); - GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab); - std::vector masked_logits; - for (size_t i = 0; i < n_logits; ++i) { - if (logits[i] <= -1e9f) { - masked_logits.push_back(logits[i]); - } - } if (temp <= 0.0f && delta >= 0.0f) { - GGML_ASSERT(masked_logits.size() == (size_t) test_ctx.n_vocab - 1); + GGML_ASSERT(n_logits == 1); } else { - printf("masked logits size: %zu\n", masked_logits.size()); - GGML_ASSERT(masked_logits.size() == 0); + GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab); } };