sampling : fix top-p

This commit is contained in:
Georgi Gerganov 2025-12-07 17:11:50 +02:00
parent 42125f0e10
commit 72e3681073
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
1 changed files with 8 additions and 4 deletions

View File

@ -1409,13 +1409,17 @@ static void llama_sampler_top_p_backend_apply(
struct ggml_tensor * idxf = ggml_sum(ctx, mask);
ggml_set_name(idxf, "dist_index_f32");
// Make top-p inclusive (i.e. return all values such that cum_sum/cdf >= p)
struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]);
// prevent out-of-bounds access
idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1);
// construct ones tensor to set the value in the mask
struct ggml_tensor * ones = ggml_clamp(ctx, mask_reshaped, 1.0f, 1.0f);
struct ggml_tensor * ones = ggml_scale_bias(ctx, idxf, 0.0f, 1.0f);
ggml_set_name(ones, "top_p_ones");
mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, ggml_repeat(ctx, idxf, mask), GGML_TYPE_I32));
// Make top-p inclusive (i.e. return all values such that cum_sum/cdf >= p)
struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]);
mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32));
mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]);
// Use ggml_scale_bias (output = (a * s) + b) which in this case becomes: