Factor out `ggml_sort` into its own function

This commit is contained in:
Oliver Simons 2025-12-01 14:46:47 +01:00
parent 16451d6bc3
commit ae0bb6a6da
1 changed file with 9 additions and 10 deletions

View File

@ -1107,14 +1107,19 @@ static void llama_sampler_top_p_backend_apply(
struct llama_sampler_data * data) {
auto * sctx = (llama_sampler_top_p *) smpl->ctx;
// Helper: builds the graph nodes that reorder the single-row tensor `a` using the index tensor `b`.
auto ggml_sort = [& ctx](struct ggml_tensor * a, struct ggml_tensor * b) {
    // Only single-row inputs are accepted.
    GGML_ASSERT(ggml_nrows(a) == 1);
    const auto n_elements = a->ne[0];
    // Reshape `a` to (1, n) so ggml_get_rows can pick entries through `b`, then flatten the result back to 1D.
    struct ggml_tensor * picked = ggml_get_rows(ctx, ggml_reshape_2d(ctx, a, 1, n_elements), b);
    return ggml_reshape_1d(ctx, picked, n_elements);
};
// Get the sorted logits in descending order.
struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC);
ggml_set_name(sorted_idx, "top_p_sorted_idx");
// Do the sorting via reshape + get_rows
struct ggml_tensor * logits_reshaped = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
struct ggml_tensor * sorted_logits_reshaped = ggml_get_rows(ctx, logits_reshaped, sorted_idx);
struct ggml_tensor * sorted_logits = ggml_reshape_1d(ctx, sorted_logits_reshaped, data->logits->ne[0]);
struct ggml_tensor * sorted_logits = ggml_sort(data->logits, sorted_idx);
ggml_set_name(sorted_logits, "top_p_sorted_logits");
struct ggml_tensor * softmax = ggml_soft_max(ctx, sorted_logits);
@ -1122,13 +1127,7 @@ static void llama_sampler_top_p_backend_apply(
// If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates.
if (data->candidates != nullptr) {
struct ggml_tensor * candidates_reshaped = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
ggml_set_name(candidates_reshaped, "top_p_candidates_reshaped");
struct ggml_tensor * sorted_candidates = ggml_get_rows(ctx, candidates_reshaped, sorted_idx);
ggml_set_name(sorted_candidates, "top_p_sorted_candidates");
data->candidates = ggml_reshape_1d(ctx, sorted_candidates, data->candidates->ne[0]);
data->candidates = ggml_sort(data->candidates, sorted_idx);
ggml_set_name(data->candidates, "top_p_candidates");
} else {
data->candidates = sorted_idx;