sampling : also apply sorting in backend path when blue noise rng is selected

This commit is contained in:
Jan Boon 2026-02-07 04:27:52 +00:00
parent 267cd808a2
commit 1b1b2cbe0e
1 changed file with 24 additions and 0 deletions

View File

@@ -1361,6 +1361,30 @@ static void llama_sampler_dist_backend_apply(
ggml_set_name (sctx->inp_uniform, "uniform");
ggml_set_input(sctx->inp_uniform);
// If the RNG requires sorted input (e.g., blue noise), sort logits first
// so the CDF walk operates in probability-rank space, not arbitrary vocab order.
//
// When this branch runs, data->logits is replaced by a reordered tensor and
// data->candidates is always non-null afterwards. The softmax that follows
// therefore sees logits in descending order.
if (sctx->rng->requires_sorted()) {
// Local gather helper: returns the elements of `a` reordered by the index
// tensor `b`. `a` must consist of a single row (asserted). It is reshaped to
// 2-D with dims (1, a->ne[0]) so that ggml_get_rows can select by index, and
// the result is reshaped back to 1-D with a->ne[0] elements.
// NOTE(review): relies on ggml treating the second dim as the row count, so
// each element of `a` becomes its own row — confirm against ggml_get_rows.
auto ggml_sort = [ctx](struct ggml_tensor * a, struct ggml_tensor * b) {
GGML_ASSERT(ggml_nrows(a) == 1);
struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]);
struct ggml_tensor * a_sorted = ggml_get_rows(ctx, a_reshaped, b);
return ggml_reshape_1d(ctx, a_sorted, a->ne[0]);
};
// Index permutation that orders the logits in descending order.
struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC);
ggml_set_name(sorted_idx, "dist_sorted_idx");
// Replace the logits with their sorted counterpart; later graph nodes built
// from data->logits consume the sorted tensor.
data->logits = ggml_sort(data->logits, sorted_idx);
ggml_set_name(data->logits, "dist_sorted_logits");
// Keep candidates aligned with the reordered logits: permute an existing
// candidates tensor with the same indices, or, when none is set, use the
// sort indices themselves as the candidates.
// NOTE(review): presumably candidates holds token ids, so that a position in
// the sorted logits can be mapped back to a vocab entry — confirm with the
// code that reads data->candidates.
if (data->candidates) {
data->candidates = ggml_sort(data->candidates, sorted_idx);
} else {
data->candidates = sorted_idx;
}
ggml_set_name(data->candidates, "dist_sorted_candidates");
}
struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
ggml_set_name(probs, "dist_probs");