diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index dd208ccfc1..8c1570761d 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1107,14 +1107,19 @@ static void llama_sampler_top_p_backend_apply( struct llama_sampler_data * data) { auto * sctx = (llama_sampler_top_p *) smpl->ctx; + auto ggml_sort = [& ctx](struct ggml_tensor * a, struct ggml_tensor * b) { + GGML_ASSERT(ggml_nrows(a) == 1); + struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]); + struct ggml_tensor * a_sorted = ggml_get_rows(ctx, a_reshaped, b); + return ggml_reshape_1d(ctx, a_sorted, a->ne[0]); + }; + // Get the sorted logits in descending order. struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC); ggml_set_name(sorted_idx, "top_p_sorted_idx"); // Do the sorting via reshape + get_rows - struct ggml_tensor * logits_reshaped = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); - struct ggml_tensor * sorted_logits_reshaped = ggml_get_rows(ctx, logits_reshaped, sorted_idx); - struct ggml_tensor * sorted_logits = ggml_reshape_1d(ctx, sorted_logits_reshaped, data->logits->ne[0]); + struct ggml_tensor * sorted_logits = ggml_sort(data->logits, sorted_idx); ggml_set_name(sorted_logits, "top_p_sorted_logits"); struct ggml_tensor * softmax = ggml_soft_max(ctx, sorted_logits); @@ -1122,13 +1127,7 @@ static void llama_sampler_top_p_backend_apply( // If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates. if (data->candidates != nullptr) { - struct ggml_tensor * candidates_reshaped = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]); - ggml_set_name(candidates_reshaped, "top_p_candidates_reshaped"); - - struct ggml_tensor * sorted_candidates = ggml_get_rows(ctx, candidates_reshaped, sorted_idx); - ggml_set_name(sorted_candidates, "top_p_sorted_candidates"); - - data->candidates = ggml_reshape_1d(ctx, sorted_candidates, data->candidates->ne[0]); + data->candidates = ggml_sort(data->candidates, sorted_idx); ggml_set_name(data->candidates, "top_p_candidates"); } else { data->candidates = sorted_idx;