squash! sampling : fix backend temp sampler for zero temperature
This modifies the parent commit to simply return the most probable token instead of masking the logits.
parent 516af33ca6
commit db8972e251
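At temperature zero, sampling degenerates to a greedy argmax over the logits, which is why the graph can emit the winning index directly via ggml_argmax instead of constructing a mask. As a minimal CPU-side sketch of what that reduces to (the greedy_sample helper below is illustrative only, not part of the patch):

    // Illustrative only: CPU equivalent of the backend argmax path.
    #include <algorithm>
    #include <cstdint>
    #include <iterator>
    #include <vector>

    using llama_token = int32_t; // matches the llama.cpp typedef

    static llama_token greedy_sample(const std::vector<float> & logits) {
        // At temperature 0 the sampled token is simply the index of
        // the largest logit.
        return (llama_token) std::distance(
            logits.begin(), std::max_element(logits.begin(), logits.end()));
    }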
@@ -1607,42 +1607,8 @@ static void llama_sampler_temp_backend_apply(
     struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
     ggml_set_name(max_idx, "temp_max_idx");
 
-    // Reshape logits to 2D so we can use get_rows.
-    struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
-    ggml_set_name(logits_rows, "temp_logits_rows");
-
-    // Get the max logit value.
-    struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
-    ggml_set_name(max_logit, "temp_max_logit");
-
-    // Repeat max_logit to match logits shape for element-wise operations.
-    struct ggml_tensor * max_logit_repeated = ggml_repeat(ctx, max_logit, data->logits);
-    ggml_set_name(max_logit_repeated, "temp_max_logit_repeated");
-
-    // Compute diff = max - logits.
-    // At max_idx position this value will be zero, and positive elsewhere.
-    struct ggml_tensor * diff = ggml_sub(ctx, max_logit_repeated, data->logits);
-    ggml_set_name(diff, "temp_diff");
-
-    // Subtract small epsilon to make max position negative.
-    // This ensures ggml_step returns 0 at max across all backends.
-    struct ggml_tensor * diff_eps = ggml_scale_bias(ctx, diff, 1.0f, -1e-6f);
-    ggml_set_name(diff_eps, "temp_diff_eps");
-
-    // Create mask: max position gets 0, everything else gets 1.
-    struct ggml_tensor * mask = ggml_step(ctx, diff_eps);
-    ggml_set_name(mask, "temp_mask");
-
-    // Convert mask to bias: -1e9 for non-max, 0 for max
-    const float large_val = 1e9f;
-    struct ggml_tensor * bias = ggml_scale_bias(ctx, mask, -large_val, 0.0f);
-    ggml_set_name(bias, "temp_bias");
-
-    // Add the bias to logits to mask out non-max tokens.
-    data->logits = ggml_add(ctx, data->logits, bias);
-    ggml_set_name(data->logits, "temp_zero_logits");
-
-    ggml_build_forward_expand(gf, data->logits);
+    // Set the sampled token to the most probable token.
+    data->sampled = max_idx;
 
     return;
 }
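For reference, the masking path removed above turned the logits into a 0/1 mask with ggml_step; because diff is exactly zero at the max position and step(0) is not guaranteed to behave consistently across backends, a small epsilon first nudged the max position slightly negative. A toy standalone sketch of that trick, with hypothetical logit values:

    #include <cstdio>

    // step(x) as the removed code relied on it: 1 for x > 0, else 0.
    static float step(float x) { return x > 0.0f ? 1.0f : 0.0f; }

    int main() {
        const float logits[4] = { 1.5f, 4.0f, -2.0f, 0.5f }; // max at index 1
        const float max_logit = 4.0f;
        for (int i = 0; i < 4; ++i) {
            const float diff = max_logit - logits[i]; // 0 at max, > 0 elsewhere
            const float mask = step(diff - 1e-6f);    // 0 at max, 1 elsewhere
            const float bias = -1e9f * mask;          // 0 at max, -1e9 elsewhere
            printf("i=%d logit=%5.1f masked=%g\n", i, logits[i], logits[i] + bias);
        }
        return 0;
    }

The new code sidesteps all of this by storing the argmax index in data->sampled directly.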
@@ -455,31 +455,14 @@ static void test_backend_temp_sampling(const char * model_path) {
         }
 
         int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-        int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
-        GGML_ASSERT(n_logits == test_ctx.n_vocab);
-
-        // Sample from sequence using CPU sampler
-        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
-        struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
-        llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
-
-        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
-        const std::string token_str = test_ctx.token_to_piece(token, false);
-        printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str());
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 
-        // Verify that only one logit is available and the rest are masked out
         float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
-        std::vector<float> unmasked;
-        for (int i = 0; i < n_logits; ++i) {
-            if (logits[i] > -1e9f) {
-                unmasked.push_back(logits[i]);
-            }
-        }
-        GGML_ASSERT(unmasked.size() == 1);
-        printf("Temperature %.1f test: unmasked size: %d\n", temp, (int)unmasked.size());
+        GGML_ASSERT(logits == nullptr);
+        int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        GGML_ASSERT(n_logits == 0);
 
-        llama_sampler_free(chain);
     };
 
     test_argmax_temp(0.0f);