sampling : fix backend temp sampling to use logits masking
This commit is contained in:
parent
10bd640aae
commit
ac9e164714
|
|
@ -1587,7 +1587,7 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
|
||||||
delete (llama_sampler_temp *) smpl->ctx;
|
delete (llama_sampler_temp *) smpl->ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void temp_sampling(
|
static void llama_sampler_backend_temp_sampling(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_cgraph * gf,
|
struct ggml_cgraph * gf,
|
||||||
struct llama_sampler_data * data,
|
struct llama_sampler_data * data,
|
||||||
|
|
@ -1597,8 +1597,32 @@ static void temp_sampling(
|
||||||
struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
|
struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
|
||||||
ggml_set_name(max_idx, "temp_max_idx");
|
ggml_set_name(max_idx, "temp_max_idx");
|
||||||
|
|
||||||
// Set the sampled token to the most probable token.
|
// Reshape to 2D and so we can use get_rows.
|
||||||
data->sampled = max_idx;
|
struct ggml_tensor * logits_2d = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
|
||||||
|
ggml_set_name(logits_2d, "temp_logits_2d");
|
||||||
|
struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_2d, max_idx);
|
||||||
|
ggml_set_name(max_logit, "temp_max_logit");
|
||||||
|
|
||||||
|
// Subtract the max_logit from all logits.
|
||||||
|
struct ggml_tensor * diff = ggml_sub(ctx, data->logits, max_logit);
|
||||||
|
ggml_set_name(diff, "temp_diff");
|
||||||
|
|
||||||
|
// Add small epsilon to make max position strictly positive.
|
||||||
|
struct ggml_tensor * diff_eps = ggml_scale_bias(ctx, diff, 1.0f, 1e-6f);
|
||||||
|
ggml_set_name(diff_eps, "temp_diff_eps");
|
||||||
|
|
||||||
|
// Create the mask for the max logit.
|
||||||
|
struct ggml_tensor * mask = ggml_step(ctx, diff_eps);
|
||||||
|
ggml_set_name(mask, "temp_mask");
|
||||||
|
|
||||||
|
// Create the bias.
|
||||||
|
const float large_val = 1e9f;
|
||||||
|
struct ggml_tensor * bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
|
||||||
|
ggml_set_name(bias, "temp_bias");
|
||||||
|
|
||||||
|
// Add the bias to the logits.
|
||||||
|
data->logits = ggml_add(ctx, data->logits, bias);
|
||||||
|
ggml_build_forward_expand(gf, data->logits);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1618,7 +1642,7 @@ static void llama_sampler_temp_backend_apply(
|
||||||
struct ggml_cgraph * gf,
|
struct ggml_cgraph * gf,
|
||||||
struct llama_sampler_data * data) {
|
struct llama_sampler_data * data) {
|
||||||
auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
|
auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
|
||||||
temp_sampling(ctx, gf, data, ctx_data->temp);
|
llama_sampler_backend_temp_sampling(ctx, gf, data, ctx_data->temp);
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct llama_sampler_i llama_sampler_temp_i = {
|
static struct llama_sampler_i llama_sampler_temp_i = {
|
||||||
|
|
@ -1750,7 +1774,7 @@ static void llama_sampler_temp_ext_backend_apply(
|
||||||
|
|
||||||
// Revert to standard temperature scaling if delta or temp are non-positive.
|
// Revert to standard temperature scaling if delta or temp are non-positive.
|
||||||
if (ctx_data->delta <= 0.0f || ctx_data->temp <= 0.0f) {
|
if (ctx_data->delta <= 0.0f || ctx_data->temp <= 0.0f) {
|
||||||
temp_sampling(ctx, gf, data, ctx_data->temp);
|
llama_sampler_backend_temp_sampling(ctx, gf, data, ctx_data->temp);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -456,13 +456,17 @@ static void test_backend_temp_sampling(const char * model_path) {
|
||||||
|
|
||||||
int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
|
int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
|
||||||
|
|
||||||
llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
|
|
||||||
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
|
|
||||||
|
|
||||||
float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
|
float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
|
||||||
GGML_ASSERT(logits == nullptr);
|
uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
|
||||||
int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
|
GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
|
||||||
GGML_ASSERT(n_logits == 0);
|
|
||||||
|
std::vector<float> masked_logits;
|
||||||
|
for (size_t i = 0; i < n_logits; ++i) {
|
||||||
|
if (logits[i] <= -1e9f) {
|
||||||
|
masked_logits.push_back(logits[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
GGML_ASSERT(masked_logits.size() == (size_t) test_ctx.n_vocab - 1);
|
||||||
};
|
};
|
||||||
|
|
||||||
test_argmax_temp(0.0f);
|
test_argmax_temp(0.0f);
|
||||||
|
|
@ -531,22 +535,27 @@ static void test_backend_temp_ext_sampling(const char * model_path) {
|
||||||
|
|
||||||
int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
|
int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
|
||||||
|
|
||||||
llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
|
float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
|
||||||
|
uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
|
||||||
|
GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab);
|
||||||
|
|
||||||
if (temp <= 0.0f) {
|
std::vector<float> masked_logits;
|
||||||
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
|
for (size_t i = 0; i < n_logits; ++i) {
|
||||||
GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
|
if (logits[i] <= -1e9f) {
|
||||||
GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx) == 0);
|
masked_logits.push_back(logits[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (temp <= 0.0f && delta >= 0.0f) {
|
||||||
|
GGML_ASSERT(masked_logits.size() == (size_t) test_ctx.n_vocab - 1);
|
||||||
} else {
|
} else {
|
||||||
GGML_ASSERT(token == LLAMA_TOKEN_NULL);
|
printf("masked logits size: %zu\n", masked_logits.size());
|
||||||
int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
|
GGML_ASSERT(masked_logits.size() == 0);
|
||||||
GGML_ASSERT(n_logits == test_ctx.n_vocab);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
test_argmax_temp(0.0f, 0.3f, 1.0f); // Greedy (temp=0)
|
test_argmax_temp(0.0f, 0.3f, 1.0f); // Greedy (temp=0)
|
||||||
test_argmax_temp(-1.0f, 0.3f, 2.0f); // Greedy (temp<0)
|
test_argmax_temp(-1.0f, 0.3f, 2.0f); // Greedy (temp<0)
|
||||||
test_argmax_temp(0.8f, 0.0f, 2.0f); // Temperature scaling (should have scaled logits)
|
test_argmax_temp(0.8f, 0.0f, 2.0f); // Temperature scaling
|
||||||
|
|
||||||
printf("backend temp_ext sampling test PASSED\n");
|
printf("backend temp_ext sampling test PASSED\n");
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue