sampling : fix backend temp sampler for zero temperature
This commit fixes the implementation of the backend temperature sampler for the case when the temperature is set to zero (or negative). The sampler now correctly selects the most probable token by masking out all other tokens in the logits.
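For reference, a minimal CPU-side sketch of the masking idea (not part of the commit, which expresses the same thing as ggml graph operations in the diff below): find the index of the largest logit, then push every other logit down by a large constant so that any downstream sampler can only pick the most probable token. The function name and parameters here are illustrative only.

// Sketch only: plain-array version of what the graph code below computes.
// mask_all_but_max and n_vocab are illustrative names, not from the commit.
static void mask_all_but_max(float * logits, int n_vocab) {
    // find the index of the most probable token
    int max_idx = 0;
    for (int i = 1; i < n_vocab; ++i) {
        if (logits[i] > logits[max_idx]) {
            max_idx = i;
        }
    }
    // push every other logit down so it can never be selected
    const float large_val = 1e9f;
    for (int i = 0; i < n_vocab; ++i) {
        if (i != max_idx) {
            logits[i] -= large_val; // same effect as adding the -1e9 bias tensor
        }
    }
}

On the backend this has to be built as graph operations, which is why the change below goes through ggml_argmax, ggml_step and ggml_scale_bias instead of a simple loop.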
parent 988261b18d
commit 739b597804
@@ -1603,7 +1603,46 @@ static void llama_sampler_temp_backend_apply(
    auto * ctx_data = (llama_sampler_temp *) smpl->ctx;

    if (ctx_data->temp <= 0.0f) {
        // TODO: this is incorrect - find the most probable token instead
        // Find the most probable token index.
        struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
        ggml_set_name(max_idx, "temp_max_idx");

        // Reshape logits to 2D so we can use get_rows.
        struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
        ggml_set_name(logits_rows, "temp_logits_rows");

        // Get the max logit value.
        struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
        ggml_set_name(max_logit, "temp_max_logit");

        // Repeat max_logit to match logits shape for element-wise operations.
        struct ggml_tensor * max_logit_repeated = ggml_repeat(ctx, max_logit, data->logits);
        ggml_set_name(max_logit_repeated, "temp_max_logit_repeated");

        // Compute diff = max - logits.
        // At max_idx position this value will be zero, and positive elsewhere.
        struct ggml_tensor * diff = ggml_sub(ctx, max_logit_repeated, data->logits);
        ggml_set_name(diff, "temp_diff");

        // Subtract small epsilon to make max position negative.
        // This ensures ggml_step returns 0 at max across all backends.
        struct ggml_tensor * diff_eps = ggml_scale_bias(ctx, diff, 1.0f, -1e-6f);
        ggml_set_name(diff_eps, "temp_diff_eps");

        // Create mask: max position gets 0, everything else gets 1.
        struct ggml_tensor * mask = ggml_step(ctx, diff_eps);
        ggml_set_name(mask, "temp_mask");

        // Convert mask to bias: -1e9 for non-max, 0 for max
        const float large_val = 1e9f;
        struct ggml_tensor * bias = ggml_scale_bias(ctx, mask, -large_val, 0.0f);
        ggml_set_name(bias, "temp_bias");

        // Add the bias to logits to mask out non-max tokens.
        data->logits = ggml_add(ctx, data->logits, bias);
        ggml_set_name(data->logits, "temp_zero_logits");

        ggml_build_forward_expand(gf, data->logits);
        return;
    }
@@ -250,6 +250,15 @@ struct test_model_context {
        return piece;
    }

    void reset() {
        if (ctx) {
            llama_free(ctx);
            ctx = nullptr;
        }
        seq_positions.clear();
        last_batch_info.clear();
    }

    void cleanup() {
        if (ctx) {
            llama_free(ctx);
@@ -360,36 +369,96 @@ static void test_backend_temp_sampling(const char * model_path) {
static void test_backend_temp_sampling(const char * model_path) {
    test_model_context test_ctx;

    const float temp_0 = 0.8f;
    struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
    struct llama_sampler * backend_sampler_chain_0 = llama_sampler_chain_init(backend_chain_params_0);
    llama_sampler_chain_add(backend_sampler_chain_0, llama_sampler_init_temp(temp_0));

    const float temp_1 = 0.1f;
    struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params();
    struct llama_sampler * backend_sampler_chain_1 = llama_sampler_chain_init(backend_chain_params_1);
    llama_sampler_chain_add(backend_sampler_chain_1, llama_sampler_init_temp(temp_1));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
        { 0, backend_sampler_chain_0 },
        { 1, backend_sampler_chain_1 }
    };

    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
        return;
    }

    if (!test_ctx.decode({{0, "Some where over"}, {1, "Once upon a"}})) {
        GGML_ASSERT(false && "Failed to decode token");
    }

    // Verify sequence 0
    {
        int32_t batch_idx = test_ctx.idx_for_seq(0);
        const float temp_0 = 0.8f;
        struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
        struct llama_sampler * backend_sampler_chain_0 = llama_sampler_chain_init(backend_chain_params_0);
        llama_sampler_chain_add(backend_sampler_chain_0, llama_sampler_init_temp(temp_0));

        const float temp_1 = 0.1f;
        struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params();
        struct llama_sampler * backend_sampler_chain_1 = llama_sampler_chain_init(backend_chain_params_1);
        llama_sampler_chain_add(backend_sampler_chain_1, llama_sampler_init_temp(temp_1));

        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
            { 0, backend_sampler_chain_0 },
            { 1, backend_sampler_chain_1 }
        };

        if (!test_ctx.setup(model_path, backend_sampler_configs)) {
            return;
        }

        if (!test_ctx.decode({{0, "Some where over"}, {1, "Once upon a"}})) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        // Verify sequence 0
        {
            int32_t batch_idx = test_ctx.idx_for_seq(0);
            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
            GGML_ASSERT(n_logits == test_ctx.n_vocab);

            // Sample from sequence 0 using CPU sampler
            struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
            struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
            llama_sampler_chain_add(chain, llama_sampler_init_dist(18));

            llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
            const std::string token_str = test_ctx.token_to_piece(token, false);
            printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

            llama_sampler_free(chain);
        }

        // Verify sequence 1
        {
            int32_t batch_idx = test_ctx.idx_for_seq(1);

            // Sample from sequence 1 using CPU sampler
            struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
            struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
            llama_sampler_chain_add(chain, llama_sampler_init_dist(18));

            llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
            const std::string token_str = test_ctx.token_to_piece(token, false);
            printf("Sequence 1 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

            llama_sampler_free(chain);
        }
    }

    // lambda for testing non-positive temperature values.
    auto test_argmax_temp = [&](float temp) {
        printf("\nTesting temperature = %.1f\n", temp);

        test_ctx.reset();

        int seq_id = 0;
        struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
        struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
        llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_temp(temp));

        std::vector<llama_sampler_seq_config> backend_sampler_configs = {
            { seq_id, backend_sampler_chain },
        };

        if (!test_ctx.setup(model_path, backend_sampler_configs)) {
            return;
        }

        if (!test_ctx.decode({{seq_id, "Once"}})) {
            GGML_ASSERT(false && "Failed to decode token");
        }

        int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
        int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
        GGML_ASSERT(n_logits == test_ctx.n_vocab);

        // Sample from sequence 0 using CPU sampler
        // Sample from sequence using CPU sampler
        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
        struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
        llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
@@ -399,25 +468,22 @@ static void test_backend_temp_sampling(const char * model_path) {
        printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

        llama_sampler_free(chain);
    }

    // Verify sequence 1
    {
        int32_t batch_idx = test_ctx.idx_for_seq(1);

        // Sample from sequence 1 using CPU sampler
        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
        struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
        llama_sampler_chain_add(chain, llama_sampler_init_dist(18));

        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("Sequence 1 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
        // Verify that only one logit is available and the rest are masked out
        float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
        std::vector<float> unmasked;
        for (int i = 0; i < n_logits; ++i) {
            if (logits[i] > -1e9f) {
                unmasked.push_back(logits[i]);
            }
        }
        GGML_ASSERT(unmasked.size() == 1);
        printf("Temperature %.1f test: unmasked size: %d\n", temp, (int)unmasked.size());

        llama_sampler_free(chain);
    }
    };

    test_argmax_temp(0.0f);
    test_argmax_temp(-1.0f);

    printf("backend temp sampling test PASSED\n");