sampling : simplify

This commit is contained in:
Georgi Gerganov 2025-11-28 17:21:12 +02:00
parent 8cac9dee45
commit 2464d1b3fc
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
1 changed files with 3 additions and 8 deletions

View File

@ -149,8 +149,7 @@ static void llama_sampler_backend_top_k_apply_ggml(
fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
}
// TODO: temporary cont until https://github.com/ggml-org/llama.cpp/pull/17365 is merged
ggml_data->candidates = ggml_cont(ctx, top_k);
ggml_data->candidates = top_k;
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]);
struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
@ -525,12 +524,8 @@ static void llama_sampler_backend_min_p_apply_ggml(
struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p));
ggml_set_name(threshold, "min_p_threshold");
// Broadcast the threshold to match the shape of logits.
struct ggml_tensor * threshold_b = ggml_repeat(ctx, threshold, ggml_data->logits);
ggml_set_name(threshold_b, "min_p_threshold_b");
// Subtract the threshold from logits.
struct ggml_tensor * sub = ggml_sub(ctx, ggml_data->logits, threshold_b);
struct ggml_tensor * sub = ggml_sub(ctx, ggml_data->logits, threshold);
// Create a mask where logits below the threshold are 0 (discard),
// and others are 1 (keep).
@ -713,4 +708,4 @@ struct llama_sampler * llama_sampler_backend_init_top_p(float p) {
};
return sampler;
}
}