sampling : use argmax for min-p sampling
This commit is contained in:
parent
7c2bfb352e
commit
d9d736102b
|
|
@ -516,24 +516,17 @@ static void llama_sampler_backend_min_p_apply_ggml(
|
|||
ggml_set_name(softmax, "softmax");
|
||||
|
||||
// Get the sorted indices of the softmax probabilities in descending order.
|
||||
struct ggml_tensor * sorted_idx = ggml_argsort(ctx, softmax, GGML_SORT_ORDER_DESC);
|
||||
ggml_set_name(sorted_idx, "sorted_idx");
|
||||
struct ggml_tensor * max_idx = ggml_argmax(ctx, softmax);
|
||||
ggml_set_name(max_idx, "max_idx");
|
||||
|
||||
// Reshape so that each probability is its own single-element row (ne0 = 1), which lets ggml_get_rows select entries by index.
|
||||
struct ggml_tensor * softmax_rows = ggml_reshape_2d(ctx, softmax, 1, softmax->ne[0]);
|
||||
ggml_set_name(softmax_rows, "softmax_rows");
|
||||
|
||||
// Get the sorted probabilities using the sorted indices so that we can get
|
||||
// the max probability value, which will be the first entry in sorted_probs.
|
||||
struct ggml_tensor * sorted_probs = ggml_get_rows(ctx, softmax_rows, sorted_idx);
|
||||
ggml_set_name(sorted_probs, "sorted_probs");
|
||||
|
||||
// Get the max probability value from sorted_probs.
|
||||
struct ggml_tensor * p_max = ggml_view_1d(ctx, sorted_probs, 1, 0);
|
||||
ggml_set_name(p_max, "p_max");
|
||||
struct ggml_tensor * max_prob = ggml_get_rows(ctx, softmax_rows, max_idx);
|
||||
ggml_set_name(max_prob, "max_prob");
|
||||
|
||||
// Calculate the threshold value.
|
||||
struct ggml_tensor * threshold = ggml_scale(ctx, p_max, sctx->p);
|
||||
struct ggml_tensor * threshold = ggml_scale(ctx, max_prob, sctx->p);
|
||||
ggml_set_name(threshold, "min_p_threshold");
|
||||
|
||||
// Broadcast the threshold to match the shape of softmax.
|
||||
|
|
|
|||
Loading…
Reference in New Issue