diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index eb924bd503..dd208ccfc1 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1107,46 +1107,44 @@ static void llama_sampler_top_p_backend_apply(
         struct llama_sampler_data * data) {
     auto * sctx = (llama_sampler_top_p *) smpl->ctx;
 
-    struct ggml_tensor * softmax = ggml_soft_max(ctx, data->logits);
-    ggml_set_name(softmax, "top_p_softmax");
-
-    // Get the sorted indices of the softmax probabilities in descending order.
-    struct ggml_tensor * sorted_idx = ggml_argsort(ctx, softmax, GGML_SORT_ORDER_DESC);
+    // Get the indices that sort the logits in descending order.
+    struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC);
     ggml_set_name(sorted_idx, "top_p_sorted_idx");
 
     // Do the sorting via reshape + get_rows
-    struct ggml_tensor * softmax_reshaped = ggml_reshape_2d(ctx, softmax, 1, softmax->ne[0]);
-    ggml_set_name(softmax_reshaped, "top_p_softmax_reshaped");
+    struct ggml_tensor * logits_reshaped = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
+    struct ggml_tensor * sorted_logits_reshaped = ggml_get_rows(ctx, logits_reshaped, sorted_idx);
+    struct ggml_tensor * sorted_logits = ggml_reshape_1d(ctx, sorted_logits_reshaped, data->logits->ne[0]);
+    ggml_set_name(sorted_logits, "top_p_sorted_logits");
 
-    struct ggml_tensor * sorted_probs = ggml_get_rows(ctx, softmax_reshaped, sorted_idx);
-    ggml_set_name(sorted_probs, "top_p_sorted_probs");
+    struct ggml_tensor * softmax = ggml_soft_max(ctx, sorted_logits);
+    ggml_set_name(softmax, "top_p_softmax");
+
+    // If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates.
+    if (data->candidates != nullptr) {
+        struct ggml_tensor * candidates_reshaped = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]);
+        ggml_set_name(candidates_reshaped, "top_p_candidates_reshaped");
+
+        struct ggml_tensor * sorted_candidates = ggml_get_rows(ctx, candidates_reshaped, sorted_idx);
+        ggml_set_name(sorted_candidates, "top_p_sorted_candidates");
+
+        data->candidates = ggml_reshape_1d(ctx, sorted_candidates, data->candidates->ne[0]);
+        ggml_set_name(data->candidates, "top_p_candidates");
+    } else {
+        data->candidates = sorted_idx;
+        ggml_set_name(data->candidates, "top_p_candidates");
+    }
 
-    struct ggml_tensor * sorted_probs_reshaped = ggml_reshape_2d(ctx, sorted_probs, softmax->ne[0], 1);
-    ggml_set_name(sorted_probs_reshaped, "top_p_sorted_probs_reshaped");
 
     // Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM.
-    struct ggml_tensor * sorted_cdf = ggml_cumsum(ctx, sorted_probs_reshaped);
-    ggml_set_name(sorted_cdf, "top_p_sorted_cdf");
+    struct ggml_tensor * cdf = ggml_cumsum(ctx, softmax);
+    ggml_set_name(cdf, "top_p_cdf");
 
+    // TODO: Make it inclusive of probability p
     // Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep
-    struct ggml_tensor * sorted_cdf_scaled = ggml_scale_bias(ctx, sorted_cdf, -1.0f, sctx->p);
-    ggml_set_name(sorted_cdf_scaled, "top_p_sorted_cdf_scaled");
+    struct ggml_tensor * cdf_scaled = ggml_scale_bias(ctx, cdf, -1.0f, sctx->p);
+    ggml_set_name(cdf_scaled, "top_p_cdf_scaled");
 
-    struct ggml_tensor * sorted_mask = ggml_step(ctx, sorted_cdf_scaled);
-    ggml_set_name(sorted_mask, "top_p_sorted_mask");
-
-    // reverse sorting by argsort(argsort)
-    // cast to F32 since cuda only supports float inputs
-    struct ggml_tensor * reverse_argsort = ggml_argsort(ctx, ggml_cast(ctx, sorted_idx, GGML_TYPE_F32), GGML_SORT_ORDER_ASC);
-    ggml_set_name(reverse_argsort, "top_p_reverse_argsort");
-
-    // Do the sorting via reshape + get_rows
-    struct ggml_tensor * sorted_reshaped_mask = ggml_reshape_2d(ctx, sorted_mask, 1, sorted_mask->ne[0]);
-    ggml_set_name(sorted_reshaped_mask, "top_p_sorted_reshaped_mask");
-
-    struct ggml_tensor * reshaped_mask = ggml_get_rows(ctx, sorted_reshaped_mask, reverse_argsort);
-    ggml_set_name(reshaped_mask, "top_p_reshaped_mask");
-
-    struct ggml_tensor * mask = ggml_reshape_2d(ctx, reshaped_mask, sorted_mask->ne[0], 1);
+    struct ggml_tensor * mask = ggml_step(ctx, cdf_scaled);
     ggml_set_name(mask, "top_p_mask");
 
     // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
@@ -1157,9 +1155,13 @@ static void llama_sampler_top_p_backend_apply(
     struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
     ggml_set_name(top_p_bias, "top_p_bias");
 
-    data->logits = ggml_add(ctx, data->logits, top_p_bias);
+    data->logits = ggml_add(ctx, sorted_logits, top_p_bias);
     ggml_set_name(data->logits, "top_p_logits");
 
+    ggml_set_output(data->candidates);
+    ggml_build_forward_expand(gf, data->candidates);
+
+    ggml_set_output(data->logits);
     ggml_build_forward_expand(gf, data->logits);
 }
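For review convenience, here is a scalar sketch of what the rewritten graph computes, on a toy vocabulary of four tokens. This is not code from the PR and none of these names exist in llama.cpp; it just mirrors the op chain above: argsort the logits descending, gather them (the reshape + get_rows trick), softmax, inclusive cumsum into a CDF, then `step(p - cdf)` as the keep-mask, applied as a `mask * large_val - large_val` bias on the sorted logits.

```cpp
// Scalar reference for the top-p graph above -- a minimal sketch, not llama.cpp code.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
    const std::vector<float> logits = {2.0f, 0.5f, 1.0f, -1.0f}; // toy logits
    const float p         = 0.80f; // top-p threshold (sctx->p in the graph)
    const float large_val = 1e9f;  // stand-in for the graph's large_val

    // ggml_argsort(..., GGML_SORT_ORDER_DESC): indices that sort the logits descending
    std::vector<int> sorted_idx(logits.size());
    std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
    std::sort(sorted_idx.begin(), sorted_idx.end(),
              [&](int a, int b) { return logits[a] > logits[b]; });

    // ggml_soft_max over the *sorted* logits (max-subtracted for stability)
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (size_t i = 0; i < sorted_idx.size(); ++i) {
        probs[i] = std::exp(logits[sorted_idx[i]] - logits[sorted_idx[0]]);
        sum += probs[i];
    }
    for (float & pr : probs) {
        pr /= sum;
    }

    // ggml_cumsum -> inclusive CDF, then the mask/bias chain:
    //   cdf_scaled = p - cdf           (ggml_scale_bias with s = -1, b = p)
    //   mask       = step(cdf_scaled)  (1 while cdf < p -- exclusive, hence the TODO)
    //   bias       = mask * large_val - large_val   (0 = keep, -large_val = drop)
    float cdf = 0.0f;
    for (size_t i = 0; i < sorted_idx.size(); ++i) {
        cdf += probs[i];
        const float mask   = (p - cdf > 0.0f) ? 1.0f : 0.0f;
        const float biased = logits[sorted_idx[i]] + mask * large_val - large_val;
        std::printf("token %d: prob %.3f cdf %.3f -> logit %g\n",
                    sorted_idx[i], probs[i], cdf, biased);
    }
    return 0;
}
```

With these toy numbers the CDF is roughly 0.61 after the first token and 0.83 after the second, so the exclusive `step` keeps only the first token; an inclusive comparison (the TODO) would also keep the token that crosses p. Note that the output logits now stay in sorted order rather than being scattered back, which is why `data->candidates` has to carry the token ids alongside `data->logits` and both are expanded into the graph.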