sampling : implement temp_ext_backend sampling
This commit implements the backend apply function for the extended temperature (entropy-based dynamic temperature) sampler.

parent 2595818a68
commit aad5a6afd7
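
For context, the extended temperature sampler scales the logits by a dynamic temperature derived from the entropy of the token distribution. Below is a minimal scalar sketch of the computation the ggml graph in this diff builds; the helper name `temp_ext_reference` and the standalone form are illustrative only (not part of the patch), while `temp`, `delta`, and `exponent` mirror the `llama_sampler_temp_ext` fields:

#include <algorithm>
#include <cmath>
#include <vector>

// Hypothetical scalar reference of the graph built in the diff below.
static void temp_ext_reference(std::vector<float> & logits, float temp, float delta, float exponent) {
    if (delta <= 0.0f || temp <= 0.0f) {
        return; // falls back to plain temperature scaling / argmax, as in temp_sampling()
    }
    const float min_temp    = std::max(0.0f, temp - delta);
    const float max_temp    = temp + delta;
    const float max_entropy = std::log((float) logits.size());

    // softmax over the logits
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> p(logits.size());
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); i++) {
        p[i] = std::exp(logits[i] - max_logit);
        sum += p[i];
    }

    // entropy = -Σ(p * log(p)), with probabilities clamped away from 0
    float entropy = 0.0f;
    for (size_t i = 0; i < p.size(); i++) {
        const float pi = std::max(p[i] / sum, 1e-10f);
        entropy -= pi * std::log(pi);
    }
    const float norm_entropy = entropy / max_entropy;

    // dyn_temp = min_temp + (max_temp - min_temp) * norm_entropy^exponent
    const float dyn_temp = min_temp + (max_temp - min_temp) * std::pow(norm_entropy, exponent);

    for (float & l : logits) {
        l /= dyn_temp;
    }
}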
@@ -1595,14 +1595,12 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
     delete (llama_sampler_temp *) smpl->ctx;
 }
 
-static void llama_sampler_temp_backend_apply(
-        struct llama_sampler * smpl,
+static void temp_sampling(
         struct ggml_context * ctx,
         struct ggml_cgraph * gf,
-        struct llama_sampler_data * data) {
-    auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
-
-    if (ctx_data->temp <= 0.0f) {
+        struct llama_sampler_data * data,
+        float temp) {
+    if (temp <= 0.0f) {
         // Find the most probable token index.
         struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
         ggml_set_name(max_idx, "temp_max_idx");
@@ -1612,7 +1610,7 @@ static void llama_sampler_temp_backend_apply(
         return;
     }
 
-    struct ggml_tensor * scaled = ggml_scale(ctx, data->logits, 1.0f / ctx_data->temp);
+    struct ggml_tensor * scaled = ggml_scale(ctx, data->logits, 1.0f / temp);
     ggml_set_name(scaled, "temp_scaled");
 
     // Make sure the scaled tensor is contiguous for subsequent operations
@@ -1622,6 +1620,15 @@
     ggml_build_forward_expand(gf, data->logits);
 }
 
+static void llama_sampler_temp_backend_apply(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_data * data) {
+    auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
+    temp_sampling(ctx, gf, data, ctx_data->temp);
+}
+
 static struct llama_sampler_i llama_sampler_temp_i = {
     /* .name   = */ llama_sampler_temp_name,
     /* .accept = */ nullptr,
@@ -1742,7 +1749,6 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
     delete (llama_sampler_temp_ext *) smpl->ctx;
 }
 
-// TODO: deduplicate with llama_sampler_temp_backend_apply
 static void llama_sampler_temp_ext_backend_apply(
         struct llama_sampler * smpl,
         struct ggml_context * ctx,
@@ -1750,21 +1756,60 @@ static void llama_sampler_temp_ext_backend_apply(
         struct llama_sampler_data * data) {
     auto * ctx_data = (llama_sampler_temp_ext *) smpl->ctx;
 
-    // TODO: implement
-    GGML_ASSERT(ctx_data->delta <= 0.0f && "not implemented");
-
-    if (ctx_data->temp <= 0.0f) {
-        // TODO: this is incorrect - find the most probable token instead
+    // Revert to standard temperature scaling if delta or temp are non-positive.
+    if (ctx_data->delta <= 0.0f || ctx_data->temp <= 0.0f) {
+        temp_sampling(ctx, gf, data, ctx_data->temp);
         return;
     }
 
-    struct ggml_tensor * scaled = ggml_scale(ctx, data->logits, 1.0f / ctx_data->temp);
-    ggml_set_name(scaled, "temp_scaled");
+    // Calculate min_temp, max_temp, and max_entropy.
+    const float min_temp    = std::max(0.0f, ctx_data->temp - ctx_data->delta);
+    const float max_temp    = ctx_data->temp + ctx_data->delta;
+    const float max_entropy = logf(data->logits->ne[0]);
 
-    // Make sure the scaled tensor is contiguous for subsequent operations
-    data->logits = ggml_cont(ctx, scaled);
-    ggml_set_name(data->logits, "temp_scaled_logits");
+    // Calculate the probabilities.
+    struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
+    ggml_set_name(probs, "temp_ext_softmax_probs");
 
+    // Clamp probabilities to avoid log(0) which would give -inf
+    struct ggml_tensor * probs_clamped = ggml_clamp(ctx, probs, 1e-10f, 1.0f);
+    ggml_set_name(probs_clamped, "temp_ext_probs_clamped");
+
+    // Calculate the entropy, entropy = -Σ(p * log(p)).
+    struct ggml_tensor * log_probs   = ggml_log(ctx, probs_clamped);
+    struct ggml_tensor * p_log_p     = ggml_mul(ctx, probs_clamped, log_probs);
+    struct ggml_tensor * sum_p_log_p = ggml_sum(ctx, p_log_p);
+    struct ggml_tensor * entropy     = ggml_scale(ctx, sum_p_log_p, -1.0f);
+    ggml_set_name(log_probs, "temp_ext_log_probs");
+    ggml_set_name(p_log_p, "temp_ext_p_log_p");
+    ggml_set_name(sum_p_log_p, "temp_ext_sum_p_log_p");
+    ggml_set_name(entropy, "temp_ext_entropy");
+
+    // Normalize the entropy, norm_entropy = entropy / max_entropy
+    struct ggml_tensor * norm_entropy = ggml_scale(ctx, entropy, 1.0f / max_entropy);
+    ggml_set_name(norm_entropy, "temp_ext_norm_entropy");
+
+    // Calculate the dynamic temperature:
+    // dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent);
+    //
+    // Calculate powf(normalized_entropy, exponent) as
+    // norm_entropy^exponent = exp(exponent * log(norm_entropy))
+    struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy);
+    struct ggml_tensor * scaled_log       = ggml_scale(ctx, log_norm_entropy, ctx_data->exponent);
+    struct ggml_tensor * pow_entropy      = ggml_exp(ctx, scaled_log);
+    // With pow_entropy computed we can now compute dyn_temp, scaling by
+    // (max_temp - min_temp) and then adding min_temp.
+    struct ggml_tensor * dyn_temp = ggml_scale_bias(ctx, pow_entropy, max_temp - min_temp, min_temp);
+    ggml_set_name(log_norm_entropy, "temp_ext_log_norm_entropy");
+    ggml_set_name(scaled_log, "temp_ext_scaled_log");
+    ggml_set_name(pow_entropy, "temp_ext_pow_entropy");
+    ggml_set_name(dyn_temp, "temp_ext_dyn_temp");
+
+    // Scale the logits by the dynamic temperature
+    struct ggml_tensor * scaled_logits = ggml_div(ctx, data->logits, dyn_temp);
+    ggml_set_name(scaled_logits, "temp_ext_scaled_logits");
+
+    data->logits = scaled_logits;
     ggml_build_forward_expand(gf, data->logits);
 }
 
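One detail worth calling out in the hunk above: the power term norm_entropy^exponent is computed with ggml_log/ggml_scale/ggml_exp via the identity x^e = exp(e * log(x)), presumably because no element-wise pow operator is available on this path (my assumption, not stated in the commit). A throwaway scalar check of the identity:

#include <cassert>
#include <cmath>

int main() {
    // x^e == exp(e * log(x)) holds for x > 0; the clamp above keeps the
    // probabilities (and hence the normalized entropy) strictly positive.
    const float x = 0.37f;
    const float e = 1.5f;
    assert(std::fabs(std::pow(x, e) - std::exp(e * std::log(x))) < 1e-6f);
    return 0;
}
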
@@ -1777,7 +1822,7 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
     /* .free              = */ llama_sampler_temp_ext_free,
     /* .backend_init      = */ nullptr,
     /* .backend_accept    = */ nullptr,
-    /* .backend_apply     = */ nullptr,
+    /* .backend_apply     = */ llama_sampler_temp_ext_backend_apply,
     /* .backend_set_input = */ nullptr,
 };
 
@@ -1797,12 +1842,6 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
         }
     );
 
-    const bool is_backend = delta <= 0.0f;
-
-    if (is_backend) {
-        res->iface->backend_apply = llama_sampler_temp_ext_backend_apply;
-    }
-
     return res;
 }
 
@@ -472,6 +472,86 @@ static void test_backend_temp_sampling(const char * model_path) {
 
 }
 
+static void test_backend_temp_ext_sampling(const char * model_path) {
+    test_model_context test_ctx;
+
+    {
+        int seq_id = 0;
+        const float temp     = 0.8f;
+        const float delta    = 0.5f;
+        const float exponent = 1.5f;
+        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
+        struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
+        llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp_ext(temp, delta, exponent));
+
+        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
+            { seq_id, backend_sampler_chain },
+        };
+
+        if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+            return;
+        }
+
+        if (!test_ctx.decode({{seq_id, "Once upon a"}})) {
+            GGML_ASSERT(false && "Failed to decode token");
+        }
+
+        // Verify sequence 0
+        {
+            int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
+            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+            GGML_ASSERT(n_logits == test_ctx.n_vocab);
+        }
+    }
+
+    test_ctx.reset();
+
+    // Lambda for testing non-positive temp/delta/exponent values.
+    auto test_argmax_temp = [&](float temp, float delta, float exponent) {
+        printf("\nTesting temperature = %.1f, delta = %.1f, exponent = %.1f\n", temp, delta, exponent);
+
+        test_ctx.reset();
+
+        int seq_id = 0;
+        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
+        struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
+        llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp_ext(temp, delta, exponent));
+
+        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
+            { seq_id, backend_sampler_chain },
+        };
+
+        if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+            return;
+        }
+
+        if (!test_ctx.decode({{seq_id, "Once"}})) {
+            GGML_ASSERT(false && "Failed to decode token");
+        }
+
+        int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
+
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+
+        if (temp <= 0.0f) {
+            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+            GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
+            GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx) == 0);
+        } else {
+            GGML_ASSERT(token == LLAMA_TOKEN_NULL);
+            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+            GGML_ASSERT(n_logits == test_ctx.n_vocab);
+        }
+    };
+
+    test_argmax_temp( 0.0f, 0.3f, 1.0f); // Greedy (temp=0)
+    test_argmax_temp(-1.0f, 0.3f, 2.0f); // Greedy (temp<0)
+    test_argmax_temp( 0.8f, 0.0f, 2.0f); // Temperature scaling (should have scaled logits)
+
+    printf("backend temp_ext sampling test PASSED\n");
+}
+
 static void test_backend_min_p_sampling(const char * model_path) {
     test_model_context test_ctx;
 
@@ -1030,6 +1110,7 @@ static const backend_test_case BACKEND_TESTS[] = {
     { "greedy", test_backend_greedy_sampling, true },
     { "logit_bias", test_backend_logit_bias_sampling, true },
    { "temp", test_backend_temp_sampling, true },
+    { "temp_ext", test_backend_temp_ext_sampling, true },
     { "top_k", test_backend_top_k_sampling, true },
     { "multi_sequence", test_backend_multi_sequence_sampling, true },
     { "dist", test_backend_dist_sampling, true },
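
For reference, using the new sampler from client code follows the same pattern as the test above. A condensed sketch, assuming the test harness's llama_sampler_seq_config type and a context already configured for backend sampling (model loading and decoding are elided):

#include <vector>
#include "llama.h"

// Condensed usage sketch based on test_backend_temp_ext_sampling above;
// model/context setup and decoding are elided.
static std::vector<llama_sampler_seq_config> make_temp_ext_configs() {
    struct llama_sampler_chain_params params = llama_sampler_chain_default_params();
    struct llama_sampler * chain = llama_sampler_chain_init(params);

    // temp = 0.8, delta = 0.5 => dyn_temp varies over [0.3, 1.3]; exponent = 1.5
    // shapes how strongly the normalized entropy pushes toward max_temp.
    llama_sampler_chain_add(chain, llama_sampler_init_temp_ext(0.8f, 0.5f, 1.5f));

    return { { /* seq_id */ 0, chain } };
}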