diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index ec1812b067..d70b765e63 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1409,13 +1409,17 @@ static void llama_sampler_top_p_backend_apply( struct ggml_tensor * idxf = ggml_sum(ctx, mask); ggml_set_name(idxf, "dist_index_f32"); - // Make top-p inclusive (i.e. return all values such that cum_sum/cdf >= p) - struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]); + // prevent out-of-bounds access + idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1); + // construct ones tensor to set the value in the mask - struct ggml_tensor * ones = ggml_clamp(ctx, mask_reshaped, 1.0f, 1.0f); + struct ggml_tensor * ones = ggml_scale_bias(ctx, idxf, 0.0f, 1.0f); ggml_set_name(ones, "top_p_ones"); - mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, ggml_repeat(ctx, idxf, mask), GGML_TYPE_I32)); + // Make top-p inclusive (i.e. return all values such that cum_sum/cdf >= p) + struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]); + + mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32)); mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]); // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes: