Factor out `ggml_sort` into its own function
This commit is contained in:
parent
16451d6bc3
commit
ae0bb6a6da
|
|
@ -1107,14 +1107,19 @@ static void llama_sampler_top_p_backend_apply(
|
|||
struct llama_sampler_data * data) {
|
||||
auto * sctx = (llama_sampler_top_p *) smpl->ctx;
|
||||
|
||||
// Local helper: reorders the elements of the tensor `a` according to the index tensor `b`.
//
// `a` must consist of a single row (asserted below). It is reshaped to 2-D with ne0 = 1 and
// ne1 = a->ne[0], so each element becomes its own row; ggml_get_rows then picks those rows using
// `b`, and the result is reshaped back to 1-D with a->ne[0] elements.
//
// At the call sites in this function `b` is the argsort index tensor of the logits, so the
// result is `a` in sorted order.
// NOTE(review): the final reshape uses a->ne[0], so `b` presumably must hold exactly a->ne[0]
// indices - confirm against ggml_get_rows / ggml_reshape_1d.
auto ggml_sort = [& ctx](struct ggml_tensor * a, struct ggml_tensor * b) {
|
||||
GGML_ASSERT(ggml_nrows(a) == 1);
|
||||
struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]);
|
||||
struct ggml_tensor * a_sorted = ggml_get_rows(ctx, a_reshaped, b);
|
||||
return ggml_reshape_1d(ctx, a_sorted, a->ne[0]);
|
||||
};
|
||||
|
||||
// Get the indices that sort the logits in descending order.
|
||||
struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC);
|
||||
ggml_set_name(sorted_idx, "top_p_sorted_idx");
|
||||
|
||||
// Do the sorting via reshape + get_rows
|
||||
struct ggml_tensor * logits_reshaped = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
|
||||
struct ggml_tensor * sorted_logits_reshaped = ggml_get_rows(ctx, logits_reshaped, sorted_idx);
|
||||
struct ggml_tensor * sorted_logits = ggml_reshape_1d(ctx, sorted_logits_reshaped, data->logits->ne[0]);
|
||||
struct ggml_tensor * sorted_logits = ggml_sort(data->logits, sorted_idx);
|
||||
ggml_set_name(sorted_logits, "top_p_sorted_logits");
|
||||
|
||||
struct ggml_tensor * softmax = ggml_soft_max(ctx, sorted_logits);
|
||||
|
|
@ -1122,13 +1127,7 @@ static void llama_sampler_top_p_backend_apply(
|
|||
|
||||
// If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates.
|
||||
if (data->candidates != nullptr) {
|
||||
struct ggml_tensor * candidates_reshaped = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
|
||||
ggml_set_name(candidates_reshaped, "top_p_candidates_reshaped");
|
||||
|
||||
struct ggml_tensor * sorted_candidates = ggml_get_rows(ctx, candidates_reshaped, sorted_idx);
|
||||
ggml_set_name(sorted_candidates, "top_p_sorted_candidates");
|
||||
|
||||
data->candidates = ggml_reshape_1d(ctx, sorted_candidates, data->candidates->ne[0]);
|
||||
data->candidates = ggml_sort(data->candidates, sorted_idx);
|
||||
ggml_set_name(data->candidates, "top_p_candidates");
|
||||
} else {
|
||||
data->candidates = sorted_idx;
|
||||
|
|
|
|||
Loading…
Reference in New Issue