sampling : fix candidates logic

This commit is contained in:
Georgi Gerganov 2025-12-05 14:21:08 +02:00
parent 7864074fdb
commit cf74b1a8ec
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
1 changed files with 14 additions and 6 deletions

View File

@ -1215,7 +1215,12 @@ static void llama_sampler_top_k_backend_apply(
struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
ggml_set_name(top_k_rows, "top_k_rows");
data->candidates = top_k;
if (data->candidates) {
data->candidates = ggml_get_rows(ctx, data->candidates, top_k);
} else {
data->candidates = top_k;
}
data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k);
GGML_UNUSED(gf);
@ -1367,11 +1372,10 @@ static void llama_sampler_top_p_backend_apply(
// If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates.
if (data->candidates != nullptr) {
data->candidates = ggml_sort(data->candidates, sorted_idx);
ggml_set_name(data->candidates, "top_p_candidates");
} else {
data->candidates = sorted_idx;
ggml_set_name(data->candidates, "top_p_candidates");
}
ggml_set_name(data->candidates, "top_p_candidates");
// Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM.
struct ggml_tensor * cdf = ggml_cumsum(ctx, softmax);
@ -1747,11 +1751,15 @@ static void llama_sampler_backend_temp_sampling(
struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
ggml_set_name(max_idx, "temp_max_idx");
data->candidates = max_idx;
if (data->candidates) {
data->candidates = ggml_get_rows(ctx, data->candidates, max_idx);
} else {
data->candidates = max_idx;
}
struct ggml_tensor * logit = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
struct ggml_tensor * logits = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
data->logits = ggml_get_rows(ctx, logit, max_idx);
data->logits = ggml_get_rows(ctx, logits, max_idx);
return;
}