sampling : use logits directly for min-p filtering
This commit is contained in:
parent
333da805fe
commit
8cac9dee45
|
|
@ -512,31 +512,27 @@ static void llama_sampler_backend_min_p_apply_ggml(
|
|||
|
||||
auto * sctx = (llama_sampler_backend_min_p_ctx *) smpl->ctx;
|
||||
|
||||
struct ggml_tensor * softmax = ggml_soft_max(ctx, ggml_data->logits);
|
||||
ggml_set_name(softmax, "softmax");
|
||||
|
||||
// Get the sorted indices of the softmax probabilities in descending order.
|
||||
struct ggml_tensor * max_idx = ggml_argmax(ctx, softmax);
|
||||
struct ggml_tensor * max_idx = ggml_argmax(ctx, ggml_data->logits);
|
||||
ggml_set_name(max_idx, "max_idx");
|
||||
|
||||
struct ggml_tensor * softmax_rows = ggml_reshape_2d(ctx, softmax, 1, softmax->ne[0]);
|
||||
ggml_set_name(softmax_rows, "softmax_rows");
|
||||
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]);
|
||||
ggml_set_name(logits_rows, "logits_rows");
|
||||
|
||||
struct ggml_tensor * max_prob = ggml_get_rows(ctx, softmax_rows, max_idx);
|
||||
ggml_set_name(max_prob, "max_prob");
|
||||
struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
|
||||
ggml_set_name(max_logit, "max_logit");
|
||||
|
||||
// Calculate the threshold value.
|
||||
struct ggml_tensor * threshold = ggml_scale(ctx, max_prob, sctx->p);
|
||||
struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p));
|
||||
ggml_set_name(threshold, "min_p_threshold");
|
||||
|
||||
// Broadcast the threshold to match the shape of softmax.
|
||||
struct ggml_tensor * threshold_b = ggml_repeat(ctx, threshold, softmax);
|
||||
// Broadcast the threshold to match the shape of logits.
|
||||
struct ggml_tensor * threshold_b = ggml_repeat(ctx, threshold, ggml_data->logits);
|
||||
ggml_set_name(threshold_b, "min_p_threshold_b");
|
||||
|
||||
// Subtract the threshold from softmax probabilities.
|
||||
struct ggml_tensor * sub = ggml_sub(ctx, softmax, threshold_b);
|
||||
// Subtract the threshold from logits.
|
||||
struct ggml_tensor * sub = ggml_sub(ctx, ggml_data->logits, threshold_b);
|
||||
|
||||
// Create a mask where probabilities below the threshold are 0 (discard),
|
||||
// Create a mask where logits below the threshold are 0 (discard),
|
||||
// and others are 1 (keep).
|
||||
struct ggml_tensor * mask = ggml_step(ctx, sub);
|
||||
ggml_set_name(mask, "min_p_mask");
|
||||
|
|
|
|||
Loading…
Reference in New Issue