diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index d40f324b84..b9739a0f66 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1595,14 +1595,12 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
     delete (llama_sampler_temp *) smpl->ctx;
 }
 
-static void llama_sampler_temp_backend_apply(
-        struct llama_sampler * smpl,
+static void temp_sampling(
         struct ggml_context * ctx,
         struct ggml_cgraph * gf,
-        struct llama_sampler_data * data) {
-    auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
-
-    if (ctx_data->temp <= 0.0f) {
+        struct llama_sampler_data * data,
+        float temp) {
+    if (temp <= 0.0f) {
         // Find the most probable token index.
         struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
         ggml_set_name(max_idx, "temp_max_idx");
@@ -1612,7 +1610,7 @@ static void llama_sampler_temp_backend_apply(
         return;
     }
 
-    struct ggml_tensor * scaled = ggml_scale(ctx, data->logits, 1.0f / ctx_data->temp);
+    struct ggml_tensor * scaled = ggml_scale(ctx, data->logits, 1.0f / temp);
     ggml_set_name(scaled, "temp_scaled");
 
     // Make sure the scaled tensor is contiguous for subsequent operations
@@ -1622,6 +1620,15 @@ static void llama_sampler_temp_backend_apply(
     ggml_build_forward_expand(gf, data->logits);
 }
 
+static void llama_sampler_temp_backend_apply(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_data * data) {
+    auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
+    temp_sampling(ctx, gf, data, ctx_data->temp);
+}
+
 static struct llama_sampler_i llama_sampler_temp_i = {
     /* .name   = */ llama_sampler_temp_name,
     /* .accept = */ nullptr,
@@ -1742,7 +1749,6 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
     delete (llama_sampler_temp_ext *) smpl->ctx;
 }
 
-// TODO: deduplicate with llama_sampler_temp_backend_apply
 static void llama_sampler_temp_ext_backend_apply(
         struct llama_sampler * smpl,
         struct ggml_context * ctx,
@@ -1750,21 +1756,60 @@ static void llama_sampler_temp_ext_backend_apply(
         struct llama_sampler_data * data) {
     auto * ctx_data = (llama_sampler_temp_ext *) smpl->ctx;
 
-    // TODO: implement
-    GGML_ASSERT(ctx_data->delta <= 0.0f && "not implemented");
-
-    if (ctx_data->temp <= 0.0f) {
-        // TODO: this is incorrect - find the most probable token instead
+    // Revert to standard temperature scaling if delta or temp are non-positive.
+    if (ctx_data->delta <= 0.0f || ctx_data->temp <= 0.0f) {
+        temp_sampling(ctx, gf, data, ctx_data->temp);
         return;
     }
 
-    struct ggml_tensor * scaled = ggml_scale(ctx, data->logits, 1.0f / ctx_data->temp);
-    ggml_set_name(scaled, "temp_scaled");
+    // Calculate min_temp, max_temp, and max_entropy.
+    const float min_temp = std::max(0.0f, ctx_data->temp - ctx_data->delta);
+    const float max_temp = ctx_data->temp + ctx_data->delta;
+    const float max_entropy = logf(data->logits->ne[0]);
 
-    // Make sure the scaled tensor is contiguous for subsequent operations
-    data->logits = ggml_cont(ctx, scaled);
-    ggml_set_name(data->logits, "temp_scaled_logits");
+    // Calculate the probabilities.
+    struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
+    ggml_set_name(probs, "temp_ext_softmax_probs");
+    // Clamp probabilities to avoid log(0) which would give -inf
+    struct ggml_tensor * probs_clamped = ggml_clamp(ctx, probs, 1e-10f, 1.0f);
+    ggml_set_name(probs_clamped, "temp_ext_probs_clamped");
+
+    // Calculate the entropy, entropy = -Σ(p * log(p)).
+    struct ggml_tensor * log_probs = ggml_log(ctx, probs_clamped);
+    struct ggml_tensor * p_log_p = ggml_mul(ctx, probs_clamped, log_probs);
+    struct ggml_tensor * sum_p_log_p = ggml_sum(ctx, p_log_p);
+    struct ggml_tensor * entropy = ggml_scale(ctx, sum_p_log_p, -1.0f);
+    ggml_set_name(log_probs, "temp_ext_log_probs");
+    ggml_set_name(p_log_p, "temp_ext_p_log_p");
+    ggml_set_name(sum_p_log_p, "temp_ext_sum_p_log_p");
+    ggml_set_name(entropy, "temp_ext_entropy");
+
+    // Normalize the entropy, norm_entropy = entropy / max_entropy
+    struct ggml_tensor * norm_entropy = ggml_scale(ctx, entropy, 1.0f / max_entropy);
+    ggml_set_name(norm_entropy, "temp_ext_norm_entropy");
+
+    // Calculate the dynamic temperature:
+    // dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent);
+    //
+    // Calculate powf(normalized_entropy, exponent) as
+    // norm_entropy^exponent = exp(exponent * log(norm_entropy))
+    struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy);
+    struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, ctx_data->exponent);
+    struct ggml_tensor * pow_entropy = ggml_exp(ctx, scaled_log);
+    // With pow_entropy computed we can now compute dyn_temp, scaling by
+    // (max_temp - min_temp) and then adding min_temp.
+    struct ggml_tensor * dyn_temp = ggml_scale_bias(ctx, pow_entropy, max_temp - min_temp, min_temp);
+    ggml_set_name(log_norm_entropy, "temp_ext_log_norm_entropy");
+    ggml_set_name(scaled_log, "temp_ext_scaled_log");
+    ggml_set_name(pow_entropy, "temp_ext_pow_entropy");
+    ggml_set_name(dyn_temp, "temp_ext_dyn_temp");
+
+    // Scale the logits by the dynamic temperature
+    struct ggml_tensor * scaled_logits = ggml_div(ctx, data->logits, dyn_temp);
+    ggml_set_name(scaled_logits, "temp_ext_scaled_logits");
+
+    data->logits = scaled_logits;
 
     ggml_build_forward_expand(gf, data->logits);
 }
 
@@ -1777,7 +1822,7 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
     /* .free              = */ llama_sampler_temp_ext_free,
     /* .backend_init      = */ nullptr,
     /* .backend_accept    = */ nullptr,
-    /* .backend_apply     = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_temp_ext_backend_apply,
     /* .backend_set_input = */ nullptr,
 };
 
@@ -1797,12 +1842,6 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
         }
     );
 
-    const bool is_backend = delta <= 0.0f;
-
-    if (is_backend) {
-        res->iface->backend_apply = llama_sampler_temp_ext_backend_apply;
-    }
-
     return res;
 }
 
diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp
index cea892236d..f56cce6350 100644
--- a/tests/test-backend-sampler.cpp
+++ b/tests/test-backend-sampler.cpp
@@ -472,6 +472,86 @@ static void test_backend_temp_sampling(const char * model_path) {
 
 }
 
+static void test_backend_temp_ext_sampling(const char * model_path) {
+    test_model_context test_ctx;
+
+    {
+        int seq_id = 0;
+        const float temp = 0.8f;
+        const float delta = 0.5f;
+        const float exponent = 1.5f;
+        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
+        struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
+        llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp_ext(temp, delta, exponent));
+
+        std::vector backend_sampler_configs = {
+            { seq_id, backend_sampler_chain },
+        };
+
+        if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+            return;
+        }
+
+        if (!test_ctx.decode({{seq_id, "Once upon a"}})) {
+            GGML_ASSERT(false && "Failed to decode token");
+        }
+
+        // Verify sequence 0
+        {
+            int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
+            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+            GGML_ASSERT(n_logits == test_ctx.n_vocab);
+        }
+    }
+
+    test_ctx.reset();
+
+    // Lambda for testing non-positive temp/delta/exponent values.
+    auto test_argmax_temp = [&](float temp, float delta, float exponent) {
+        printf("\nTesting temperature = %.1f, delta = %.1f, exponent = %.1f\n", temp, delta, exponent);
+
+        test_ctx.reset();
+
+        int seq_id = 0;
+        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
+        struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
+        llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp_ext(temp, delta, exponent));
+
+        std::vector backend_sampler_configs = {
+            { seq_id, backend_sampler_chain },
+        };
+
+        if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+            return;
+        }
+
+        if (!test_ctx.decode({{seq_id, "Once"}})) {
+            GGML_ASSERT(false && "Failed to decode token");
+        }
+
+        int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
+
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+
+        if (temp <= 0.0f) {
+            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+            GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
+            GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx) == 0);
+        } else {
+            GGML_ASSERT(token == LLAMA_TOKEN_NULL);
+            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+            GGML_ASSERT(n_logits == test_ctx.n_vocab);
+        }
+    };
+
+    test_argmax_temp(0.0f, 0.3f, 1.0f);   // Greedy (temp=0)
+    test_argmax_temp(-1.0f, 0.3f, 2.0f);  // Greedy (temp<0)
+    test_argmax_temp(0.8f, 0.0f, 2.0f);   // Temperature scaling (should have scaled logits)
+
+    printf("backend temp_ext sampling test PASSED\n");
+
+}
+
 static void test_backend_min_p_sampling(const char * model_path) {
     test_model_context test_ctx;
 
@@ -1030,6 +1110,7 @@ static const backend_test_case BACKEND_TESTS[] = {
     { "greedy",         test_backend_greedy_sampling,         true },
     { "logit_bias",     test_backend_logit_bias_sampling,     true },
    { "temp",           test_backend_temp_sampling,           true },
+    { "temp_ext",       test_backend_temp_ext_sampling,       true },
     { "top_k",          test_backend_top_k_sampling,          true },
     { "multi_sequence", test_backend_multi_sequence_sampling, true },
     { "dist",           test_backend_dist_sampling,           true },