sampling : also apply sorting in backend path when blue noise rng is selected
This commit is contained in:
parent
267cd808a2
commit
1b1b2cbe0e
|
|
@ -1361,6 +1361,30 @@ static void llama_sampler_dist_backend_apply(
|
|||
ggml_set_name (sctx->inp_uniform, "uniform");
|
||||
ggml_set_input(sctx->inp_uniform);
|
||||
|
||||
// If the RNG requires sorted input (e.g., blue noise), sort logits first
|
||||
// so the CDF walk operates in probability-rank space, not arbitrary vocab order.
|
||||
if (sctx->rng->requires_sorted()) {
|
||||
auto ggml_sort = [ctx](struct ggml_tensor * a, struct ggml_tensor * b) {
|
||||
GGML_ASSERT(ggml_nrows(a) == 1);
|
||||
struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]);
|
||||
struct ggml_tensor * a_sorted = ggml_get_rows(ctx, a_reshaped, b);
|
||||
return ggml_reshape_1d(ctx, a_sorted, a->ne[0]);
|
||||
};
|
||||
|
||||
struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC);
|
||||
ggml_set_name(sorted_idx, "dist_sorted_idx");
|
||||
|
||||
data->logits = ggml_sort(data->logits, sorted_idx);
|
||||
ggml_set_name(data->logits, "dist_sorted_logits");
|
||||
|
||||
if (data->candidates) {
|
||||
data->candidates = ggml_sort(data->candidates, sorted_idx);
|
||||
} else {
|
||||
data->candidates = sorted_idx;
|
||||
}
|
||||
ggml_set_name(data->candidates, "dist_sorted_candidates");
|
||||
}
|
||||
|
||||
struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
|
||||
ggml_set_name(probs, "dist_probs");
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue