sampling : simplify temp sampling

This commit is contained in:
Georgi Gerganov 2025-12-04 14:23:02 +02:00
parent ac9e164714
commit fce571ee51
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 7 additions and 44 deletions

View File

@ -1597,32 +1597,13 @@ static void llama_sampler_backend_temp_sampling(
struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
ggml_set_name(max_idx, "temp_max_idx");
// Reshape to 2D and so we can use get_rows.
struct ggml_tensor * logits_2d = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
ggml_set_name(logits_2d, "temp_logits_2d");
struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_2d, max_idx);
ggml_set_name(max_logit, "temp_max_logit");
data->candidates = max_idx;
// Subtract the max_logit from all logits.
struct ggml_tensor * diff = ggml_sub(ctx, data->logits, max_logit);
ggml_set_name(diff, "temp_diff");
struct ggml_tensor * logit = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
// Add small epsilon to make max position strictly positive.
struct ggml_tensor * diff_eps = ggml_scale_bias(ctx, diff, 1.0f, 1e-6f);
ggml_set_name(diff_eps, "temp_diff_eps");
// Create the mask for the max logit.
struct ggml_tensor * mask = ggml_step(ctx, diff_eps);
ggml_set_name(mask, "temp_mask");
// Create the bias.
const float large_val = 1e9f;
struct ggml_tensor * bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
ggml_set_name(bias, "temp_bias");
// Add the bias to the logits.
data->logits = ggml_add(ctx, data->logits, bias);
data->logits = ggml_get_rows(ctx, logit, max_idx);
ggml_build_forward_expand(gf, data->logits);
return;
}

View File

@ -456,17 +456,8 @@ static void test_backend_temp_sampling(const char * model_path) {
int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
std::vector<float> masked_logits;
for (size_t i = 0; i < n_logits; ++i) {
if (logits[i] <= -1e9f) {
masked_logits.push_back(logits[i]);
}
}
GGML_ASSERT(masked_logits.size() == (size_t) test_ctx.n_vocab - 1);
GGML_ASSERT(n_logits == 1);
};
test_argmax_temp(0.0f);
@ -535,21 +526,12 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
std::vector<float> masked_logits;
for (size_t i = 0; i < n_logits; ++i) {
if (logits[i] <= -1e9f) {
masked_logits.push_back(logits[i]);
}
}
if (temp <= 0.0f && delta >= 0.0f) {
GGML_ASSERT(masked_logits.size() == (size_t) test_ctx.n_vocab - 1);
GGML_ASSERT(n_logits == 1);
} else {
printf("masked logits size: %zu\n", masked_logits.size());
GGML_ASSERT(masked_logits.size() == 0);
GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
}
};