sampling : implement temp_ext_backend sampling

This commit implements the backend apply function for the extended (dynamic)
temperature sampler. The plain temperature scaling is factored out into a
shared temp_sampling helper so both the standard and the extended samplers
can reuse it.
Daniel Bevenius 2025-12-02 17:26:04 +01:00
parent 2595818a68
commit aad5a6afd7
2 changed files with 145 additions and 25 deletions
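
For context, the backend graph added below mirrors the entropy-based dynamic
temperature described by the comments in the patch: the softmax entropy of the
logits is normalized by the maximum possible entropy log(n_vocab), raised to
the configured exponent, and mapped into [temp - delta, temp + delta]. A
minimal scalar sketch of that computation follows (the helper name
dyn_temp_ref is illustrative and not part of this patch):

#include <algorithm>
#include <cmath>
#include <vector>

// Reference (scalar) version of the dynamic temperature that the ggml graph
// below builds with tensor ops. Illustrative sketch only.
static float dyn_temp_ref(const std::vector<float> & logits, float temp, float delta, float exponent) {
    const float min_temp = std::max(0.0f, temp - delta);
    const float max_temp = temp + delta;

    // softmax over the logits
    const float max_l = *std::max_element(logits.begin(), logits.end());
    std::vector<float> p(logits.size());
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i) {
        p[i] = expf(logits[i] - max_l);
        sum += p[i];
    }

    // entropy = -sum(p * log(p)), normalized by the maximum entropy log(n_vocab)
    float entropy = 0.0f;
    for (size_t i = 0; i < p.size(); ++i) {
        const float pi = std::max(p[i] / sum, 1e-10f); // clamp to avoid log(0)
        entropy -= pi * logf(pi);
    }
    const float norm_entropy = entropy / logf((float) logits.size());

    // dyn_temp = min_temp + (max_temp - min_temp) * norm_entropy^exponent
    return min_temp + (max_temp - min_temp) * powf(norm_entropy, exponent);
}

A near-uniform distribution has normalized entropy close to 1, so the dynamic
temperature approaches temp + delta; a sharply peaked distribution pushes it
toward temp - delta. The logits are then divided by this dynamic temperature
instead of the fixed temp.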


@@ -1595,14 +1595,12 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
delete (llama_sampler_temp *) smpl->ctx;
}
static void llama_sampler_temp_backend_apply(
struct llama_sampler * smpl,
static void temp_sampling(
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_data * data) {
auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
if (ctx_data->temp <= 0.0f) {
struct llama_sampler_data * data,
float temp) {
if (temp <= 0.0f) {
// Find the most probable token index.
struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
ggml_set_name(max_idx, "temp_max_idx");
@@ -1612,7 +1610,7 @@ static void llama_sampler_temp_backend_apply(
return;
}
struct ggml_tensor * scaled = ggml_scale(ctx, data->logits, 1.0f / ctx_data->temp);
struct ggml_tensor * scaled = ggml_scale(ctx, data->logits, 1.0f / temp);
ggml_set_name(scaled, "temp_scaled");
// Make sure the scaled tensor is contiguous for subsequent operations
@@ -1622,6 +1620,15 @@ static void llama_sampler_temp_backend_apply(
ggml_build_forward_expand(gf, data->logits);
}
static void llama_sampler_temp_backend_apply(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_data * data) {
auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
temp_sampling(ctx, gf, data, ctx_data->temp);
}
static struct llama_sampler_i llama_sampler_temp_i = {
/* .name = */ llama_sampler_temp_name,
/* .accept = */ nullptr,
@@ -1742,7 +1749,6 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
delete (llama_sampler_temp_ext *) smpl->ctx;
}
// TODO: deduplicate with llama_sampler_temp_backend_apply
static void llama_sampler_temp_ext_backend_apply(
struct llama_sampler * smpl,
struct ggml_context * ctx,
@@ -1750,21 +1756,60 @@ static void llama_sampler_temp_ext_backend_apply(
struct llama_sampler_data * data) {
auto * ctx_data = (llama_sampler_temp_ext *) smpl->ctx;
// TODO: implement
GGML_ASSERT(ctx_data->delta <= 0.0f && "not implemented");
if (ctx_data->temp <= 0.0f) {
// TODO: this is incorrect - find the most probable token instead
// Fall back to standard temperature scaling if delta or temp is non-positive.
if (ctx_data->delta <= 0.0f || ctx_data->temp <= 0.0f) {
temp_sampling(ctx, gf, data, ctx_data->temp);
return;
}
struct ggml_tensor * scaled = ggml_scale(ctx, data->logits, 1.0f / ctx_data->temp);
ggml_set_name(scaled, "temp_scaled");
// Calculate min_temp, max_temp, and max_entropy.
const float min_temp = std::max(0.0f, ctx_data->temp - ctx_data->delta);
const float max_temp = ctx_data->temp + ctx_data->delta;
const float max_entropy = logf(data->logits->ne[0]);
// Make sure the scaled tensor is contiguous for subsequent operations
data->logits = ggml_cont(ctx, scaled);
ggml_set_name(data->logits, "temp_scaled_logits");
// Calculate the probabilities.
struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
ggml_set_name(probs, "temp_ext_softmax_probs");
// Clamp probabilities to avoid log(0) which would give -inf
struct ggml_tensor * probs_clamped = ggml_clamp(ctx, probs, 1e-10f, 1.0f);
ggml_set_name(probs_clamped, "temp_ext_probs_clamped");
// Calculate the entropy, entropy = -Σ(p * log(p)).
struct ggml_tensor * log_probs = ggml_log(ctx, probs_clamped);
struct ggml_tensor * p_log_p = ggml_mul(ctx, probs_clamped, log_probs);
struct ggml_tensor * sum_p_log_p = ggml_sum(ctx, p_log_p);
struct ggml_tensor * entropy = ggml_scale(ctx, sum_p_log_p, -1.0f);
ggml_set_name(log_probs, "temp_ext_log_probs");
ggml_set_name(p_log_p, "temp_ext_p_log_p");
ggml_set_name(sum_p_log_p, "temp_ext_sum_p_log_p");
ggml_set_name(entropy, "temp_ext_entropy");
// Normalize the entropy, norm_entropy = entropy / max_entropy
struct ggml_tensor * norm_entropy = ggml_scale(ctx, entropy, 1.0f / max_entropy);
ggml_set_name(norm_entropy, "temp_ext_norm_entropy");
// Calculate the dynamic temperature:
// dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent);
//
// Calculate powf(normalized_entropy, exponent) as
// norm_entropy^exponent = exp(exponent * log(norm_entropy))
struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy);
struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, ctx_data->exponent);
struct ggml_tensor * pow_entropy = ggml_exp(ctx, scaled_log);
// With pow_entropy computed we can now compute dyn_temp, scaling by
// (max_temp - min_temp) and then adding min_temp.
struct ggml_tensor * dyn_temp = ggml_scale_bias(ctx, pow_entropy, max_temp - min_temp, min_temp);
ggml_set_name(log_norm_entropy, "temp_ext_log_norm_entropy");
ggml_set_name(scaled_log, "temp_ext_scaled_log");
ggml_set_name(pow_entropy, "temp_ext_pow_entropy");
ggml_set_name(dyn_temp, "temp_ext_dyn_temp");
// Scale the logits by the dynamic temperature
struct ggml_tensor * scaled_logits = ggml_div(ctx, data->logits, dyn_temp);
ggml_set_name(scaled_logits, "temp_ext_scaled_logits");
data->logits = scaled_logits;
ggml_build_forward_expand(gf, data->logits);
}
@@ -1777,7 +1822,7 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
/* .free = */ llama_sampler_temp_ext_free,
/* .backend_init = */ nullptr,
/* .backend_accept = */ nullptr,
/* .backend_apply = */ nullptr,
/* .backend_apply = */ llama_sampler_temp_ext_backend_apply,
/* .backend_set_input = */ nullptr,
};
@@ -1797,12 +1842,6 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
}
);
const bool is_backend = delta <= 0.0f;
if (is_backend) {
res->iface->backend_apply = llama_sampler_temp_ext_backend_apply;
}
return res;
}
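
One detail worth noting in the graph construction above: the power
norm_entropy^exponent is computed through the identity x^e = exp(e * log(x)),
which holds for x > 0, and the earlier clamp of the probabilities to
[1e-10, 1] keeps the intermediate log() finite even when a probability
underflows to zero. A tiny standalone check of the identity (illustrative
only, not part of this patch):

#include <cassert>
#include <cmath>

int main() {
    const float x = 0.25f; // stands in for the normalized entropy
    const float e = 1.5f;  // stands in for the exponent parameter
    // powf(x, e) and expf(e * logf(x)) agree for positive x
    assert(fabsf(powf(x, e) - expf(e * logf(x))) < 1e-6f);
    return 0;
}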


@@ -472,6 +472,86 @@ static void test_backend_temp_sampling(const char * model_path) {
}
static void test_backend_temp_ext_sampling(const char * model_path) {
test_model_context test_ctx;
{
int seq_id = 0;
const float temp = 0.8f;
const float delta = 0.5f;
const float exponent = 1.5f;
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp_ext(temp, delta, exponent));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {
{ seq_id, backend_sampler_chain },
};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
return;
}
if (!test_ctx.decode({{seq_id, "Once upon a"}})) {
GGML_ASSERT(false && "Failed to decode token");
}
// Verify sequence 0
{
int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
GGML_ASSERT(n_logits == test_ctx.n_vocab);
}
}
test_ctx.reset();
// Lambda for testing non-positive temp/delta values.
auto test_argmax_temp = [&](float temp, float delta, float exponent) {
printf("\nTesting temperature = %.1f, delta = %1.f, exponent = %1.f\n", temp, delta, exponent);
test_ctx.reset();
int seq_id = 0;
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp_ext(temp, delta, exponent));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {
{ seq_id, backend_sampler_chain },
};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
return;
}
if (!test_ctx.decode({{seq_id, "Once"}})) {
GGML_ASSERT(false && "Failed to decode token");
}
int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
if (temp <= 0.0f) {
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx) == 0);
} else {
GGML_ASSERT(token == LLAMA_TOKEN_NULL);
int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
GGML_ASSERT(n_logits == test_ctx.n_vocab);
}
};
test_argmax_temp(0.0f, 0.3f, 1.0f); // Greedy (temp=0)
test_argmax_temp(-1.0f, 0.3f, 2.0f); // Greedy (temp<0)
test_argmax_temp(0.8f, 0.0f, 2.0f); // Temperature scaling (should have scaled logits)
printf("backend temp_ext sampling test PASSED\n");
}
static void test_backend_min_p_sampling(const char * model_path) {
test_model_context test_ctx;
@@ -1030,6 +1110,7 @@ static const backend_test_case BACKEND_TESTS[] = {
{ "greedy", test_backend_greedy_sampling, true },
{ "logit_bias", test_backend_logit_bias_sampling, true },
{ "temp", test_backend_temp_sampling, true },
{ "temp_ext", test_backend_temp_ext_sampling, true },
{ "top_k", test_backend_top_k_sampling, true },
{ "multi_sequence", test_backend_multi_sequence_sampling, true },
{ "dist", test_backend_dist_sampling, true },