sampling : fix backend temp sampling to use logits masking

Daniel Bevenius 2025-12-04 09:39:20 +01:00
parent 10bd640aae
commit ac9e164714
2 changed files with 53 additions and 20 deletions


@@ -1587,7 +1587,7 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
     delete (llama_sampler_temp *) smpl->ctx;
 }
 
-static void temp_sampling(
+static void llama_sampler_backend_temp_sampling(
         struct ggml_context * ctx,
         struct ggml_cgraph * gf,
         struct llama_sampler_data * data,
@@ -1597,8 +1597,32 @@ static void temp_sampling(
         struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
         ggml_set_name(max_idx, "temp_max_idx");
 
-        // Set the sampled token to the most probable token.
-        data->sampled = max_idx;
-
+        // Reshape to 2D so we can use get_rows.
+        struct ggml_tensor * logits_2d = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
+        ggml_set_name(logits_2d, "temp_logits_2d");
+
+        struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_2d, max_idx);
+        ggml_set_name(max_logit, "temp_max_logit");
+
+        // Subtract the max_logit from all logits.
+        struct ggml_tensor * diff = ggml_sub(ctx, data->logits, max_logit);
+        ggml_set_name(diff, "temp_diff");
+
+        // Add a small epsilon to make the max position strictly positive.
+        struct ggml_tensor * diff_eps = ggml_scale_bias(ctx, diff, 1.0f, 1e-6f);
+        ggml_set_name(diff_eps, "temp_diff_eps");
+
+        // Create the mask for the max logit.
+        struct ggml_tensor * mask = ggml_step(ctx, diff_eps);
+        ggml_set_name(mask, "temp_mask");
+
+        // Create the bias.
+        const float large_val = 1e9f;
+        struct ggml_tensor * bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
+        ggml_set_name(bias, "temp_bias");
+
+        // Add the bias to the logits.
+        data->logits = ggml_add(ctx, data->logits, bias);
+        ggml_build_forward_expand(gf, data->logits);
         return;
     }
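
Why this masks everything except the argmax: subtracting max_logit makes the argmax position exactly 0 and every other position negative; the 1e-6 epsilon pushes only the argmax strictly above zero; ggml_step then yields a one-hot mask (1 at the argmax, 0 elsewhere); and ggml_scale_bias(mask, 1e9, -1e9) turns that into a bias of 0 at the argmax and -1e9 everywhere else, which ggml_add folds back into the logits. A minimal scalar sketch of the same arithmetic (plain C++ outside the ggml graph; the example logit values are made up):

    #include <algorithm>
    #include <cassert>
    #include <cstdio>
    #include <vector>

    int main() {
        // Made-up stand-in for data->logits; index 1 is the argmax.
        std::vector<float> logits = { 1.5f, 3.2f, -0.7f, 3.1f };

        // ggml_argmax + ggml_get_rows: pick out the max logit.
        const float max_logit = *std::max_element(logits.begin(), logits.end());

        const float eps       = 1e-6f; // ggml_scale_bias(diff, 1.0f, 1e-6f)
        const float large_val = 1e9f;

        for (size_t i = 0; i < logits.size(); ++i) {
            const float diff_eps = (logits[i] - max_logit) + eps;  // ggml_sub + epsilon
            const float mask     = diff_eps > 0.0f ? 1.0f : 0.0f;  // ggml_step
            const float bias     = mask * large_val - large_val;   // ggml_scale_bias(mask, 1e9, -1e9)
            logits[i] += bias;                                     // ggml_add
        }

        // The argmax keeps its original value; everything else sits near -1e9.
        assert(logits[1] == 3.2f);
        for (size_t i = 0; i < logits.size(); ++i) {
            if (i != 1) {
                assert(logits[i] < -0.9e9f);
            }
        }
        printf("argmax survives, %zu logits masked\n", logits.size() - 1);
        return 0;
    }

One property worth noting from the epsilon: any logit within 1e-6 of the maximum would also pass ggml_step and survive the mask, so exact ties and near-ties are not broken here.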
@@ -1618,7 +1642,7 @@ static void llama_sampler_temp_backend_apply(
         struct ggml_cgraph * gf,
         struct llama_sampler_data * data) {
     auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
-    temp_sampling(ctx, gf, data, ctx_data->temp);
+    llama_sampler_backend_temp_sampling(ctx, gf, data, ctx_data->temp);
 }
 
 static struct llama_sampler_i llama_sampler_temp_i = {
@@ -1750,7 +1774,7 @@ static void llama_sampler_temp_ext_backend_apply(
     // Revert to standard temperature scaling if delta or temp are non-positive.
     if (ctx_data->delta <= 0.0f || ctx_data->temp <= 0.0f) {
-        temp_sampling(ctx, gf, data, ctx_data->temp);
+        llama_sampler_backend_temp_sampling(ctx, gf, data, ctx_data->temp);
         return;
     }


@@ -456,13 +456,17 @@ static void test_backend_temp_sampling(const char * model_path) {
         int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
-        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
         float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
-        GGML_ASSERT(logits == nullptr);
-        int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
-        GGML_ASSERT(n_logits == 0);
+        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
+
+        std::vector<float> masked_logits;
+        for (size_t i = 0; i < n_logits; ++i) {
+            if (logits[i] <= -1e9f) {
+                masked_logits.push_back(logits[i]);
+            }
+        }
+        GGML_ASSERT(masked_logits.size() == (size_t) test_ctx.n_vocab - 1);
     };
 
     test_argmax_temp(0.0f);
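
The expected count follows from the masking graph above: only the argmax position receives a zero bias, so the remaining n_vocab - 1 logits are shifted by -1e9 and are picked up by the logits[i] <= -1e9f check.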
@@ -531,22 +535,27 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
         int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
+        float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
+        uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
 
-        if (temp <= 0.0f) {
-            GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
-            GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
-            GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx) == 0);
+        std::vector<float> masked_logits;
+        for (size_t i = 0; i < n_logits; ++i) {
+            if (logits[i] <= -1e9f) {
+                masked_logits.push_back(logits[i]);
+            }
+        }
+
+        if (temp <= 0.0f && delta >= 0.0f) {
+            GGML_ASSERT(masked_logits.size() == (size_t) test_ctx.n_vocab - 1);
         } else {
-            GGML_ASSERT(token == LLAMA_TOKEN_NULL);
-            int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
-            GGML_ASSERT(n_logits == test_ctx.n_vocab);
+            printf("masked logits size: %zu\n", masked_logits.size());
+            GGML_ASSERT(masked_logits.size() == 0);
         }
     };
 
     test_argmax_temp(0.0f, 0.3f, 1.0f);  // Greedy (temp=0)
     test_argmax_temp(-1.0f, 0.3f, 2.0f); // Greedy (temp<0)
-    test_argmax_temp(0.8f, 0.0f, 2.0f);  // Temperature scaling (should have scaled logits)
+    test_argmax_temp(0.8f, 0.0f, 2.0f);  // Temperature scaling
 
     printf("backend temp_ext sampling test PASSED\n");