From cf74b1a8ecd239635a55d2dff8005b0e060ff14f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 5 Dec 2025 14:21:08 +0200 Subject: [PATCH] sampling : fix candidates logic --- src/llama-sampling.cpp | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 004284c6be..a37e8a8223 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1215,7 +1215,12 @@ static void llama_sampler_top_k_backend_apply( struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k); ggml_set_name(top_k_rows, "top_k_rows"); - data->candidates = top_k; + if (data->candidates) { + data->candidates = ggml_get_rows(ctx, data->candidates, top_k); + } else { + data->candidates = top_k; + } + data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k); GGML_UNUSED(gf); @@ -1367,11 +1372,10 @@ static void llama_sampler_top_p_backend_apply( // If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates. if (data->candidates != nullptr) { data->candidates = ggml_sort(data->candidates, sorted_idx); - ggml_set_name(data->candidates, "top_p_candidates"); } else { data->candidates = sorted_idx; - ggml_set_name(data->candidates, "top_p_candidates"); } + ggml_set_name(data->candidates, "top_p_candidates"); // Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM. struct ggml_tensor * cdf = ggml_cumsum(ctx, softmax); @@ -1747,11 +1751,15 @@ static void llama_sampler_backend_temp_sampling( struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits); ggml_set_name(max_idx, "temp_max_idx"); - data->candidates = max_idx; + if (data->candidates) { + data->candidates = ggml_get_rows(ctx, data->candidates, max_idx); + } else { + data->candidates = max_idx; + } - struct ggml_tensor * logit = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); + struct ggml_tensor * logits = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); - data->logits = ggml_get_rows(ctx, logit, max_idx); + data->logits = ggml_get_rows(ctx, logits, max_idx); return; }