diff --git a/common/sampling.cpp b/common/sampling.cpp
index a831eac18b..94367bd307 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -169,6 +169,7 @@ static bool common_sampler_type_has_backend_support(enum common_sampler_type typ
         case COMMON_SAMPLER_TYPE_TOP_K:
         case COMMON_SAMPLER_TYPE_TEMPERATURE:
         case COMMON_SAMPLER_TYPE_MIN_P:
+        case COMMON_SAMPLER_TYPE_TOP_P:
             return true;
         default:
             return false;
@@ -382,6 +383,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
            case COMMON_SAMPLER_TYPE_MIN_P:
                llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_min_p(params.min_p));
                break;
+            case COMMON_SAMPLER_TYPE_TOP_P:
+                llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_top_p(params.top_p));
+                break;
            default:
                GGML_ASSERT(false && "unsupported backend sampler");
        }
diff --git a/include/llama.h b/include/llama.h
index 080ac27f1f..38178d919f 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1403,6 +1403,9 @@ extern "C" {
     /// @details Min-P filtering on backend - filter tokens with a probability less than p times the maximum probability.
     LLAMA_API struct llama_sampler * llama_sampler_backend_init_min_p(float p);
 
+    /// @details Top-p (nucleus) filtering on backend - keep the most probable tokens whose cumulative probability is below p and filter out the rest.
+    LLAMA_API struct llama_sampler * llama_sampler_backend_init_top_p(float p);
+
     // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
     LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
 
diff --git a/src/llama-backend-sampler.cpp b/src/llama-backend-sampler.cpp
index 6a3893b129..a4f22055e6 100644
--- a/src/llama-backend-sampler.cpp
+++ b/src/llama-backend-sampler.cpp
@@ -596,3 +596,125 @@ struct llama_sampler * llama_sampler_backend_init_min_p(float p) {
 
     return sampler;
 }
+
+
+struct llama_sampler_backend_top_p_ctx {
+    float p;
+
+    // Only required for checking operation support and can be removed later.
+    ggml_backend_dev_t device;
+};
+
+
+static void llama_sampler_backend_top_p_init_ggml(
+        struct llama_sampler * smpl,
+        ggml_backend_buffer_type_t buft) {
+    auto * sctx = (llama_sampler_backend_top_p_ctx *) smpl->ctx;
+    sctx->device = ggml_backend_buft_get_device(buft);
+}
+
+static void llama_sampler_backend_top_p_apply_ggml(
+        struct llama_sampler * smpl,
+        struct ggml_context * ctx,
+        struct ggml_cgraph * gf,
+        struct llama_sampler_ggml_data * ggml_data) {
+    GGML_UNUSED(gf);
+
+    auto * sctx = (llama_sampler_backend_top_p_ctx *) smpl->ctx;
+
+    struct ggml_tensor * softmax = ggml_soft_max(ctx, ggml_data->logits);
+    ggml_set_name(softmax, "top_p_softmax");
+
+    // Get the sorted indices of the softmax probabilities in descending order.
+    struct ggml_tensor * sorted_idx = ggml_argsort(ctx, softmax, GGML_SORT_ORDER_DESC);
+    ggml_set_name(sorted_idx, "top_p_sorted_idx");
+
+    // Do the sorting via reshape + get_rows
+    struct ggml_tensor * softmax_reshaped = ggml_reshape_2d(ctx, softmax, 1, softmax->ne[0]);
+    ggml_set_name(softmax_reshaped, "top_p_softmax_reshaped");
+
+    struct ggml_tensor * sorted_probs = ggml_get_rows(ctx, softmax_reshaped, sorted_idx);
+    ggml_set_name(sorted_probs, "top_p_sorted_probs");
+
+    struct ggml_tensor * sorted_probs_reshaped = ggml_reshape_2d(ctx, sorted_probs, softmax->ne[0], 1);
+    ggml_set_name(sorted_probs_reshaped, "top_p_sorted_probs_reshaped");
+    // Compute the Cumulative Distribution Function (CDF) using GGML_OP_CUMSUM.
+    struct ggml_tensor * sorted_cdf = ggml_cumsum(ctx, sorted_probs_reshaped);
+    ggml_set_name(sorted_cdf, "top_p_sorted_cdf");
+
+    // Negate the CDF and add the top-p value so that ggml_step yields 1 for the tokens we want to keep.
+    struct ggml_tensor * sorted_cdf_scaled = ggml_scale_bias(ctx, sorted_cdf, -1.0f, sctx->p);
+    ggml_set_name(sorted_cdf_scaled, "top_p_sorted_cdf_scaled");
+
+    struct ggml_tensor * sorted_mask = ggml_step(ctx, sorted_cdf_scaled);
+    ggml_set_name(sorted_mask, "top_p_sorted_mask");
+
+    // Reverse the sort: argsort of the sorted indices gives the inverse permutation.
+    // Cast to F32 since the CUDA argsort implementation only supports float inputs.
+    struct ggml_tensor * reverse_argsort = ggml_argsort(ctx, ggml_cast(ctx, sorted_idx, GGML_TYPE_F32), GGML_SORT_ORDER_ASC);
+    ggml_set_name(reverse_argsort, "top_p_reverse_argsort");
+
+    // Apply the inverse permutation via reshape + get_rows
+    struct ggml_tensor * sorted_reshaped_mask = ggml_reshape_2d(ctx, sorted_mask, 1, sorted_mask->ne[0]);
+    ggml_set_name(sorted_reshaped_mask, "top_p_sorted_reshaped_mask");
+
+    struct ggml_tensor * reshaped_mask = ggml_get_rows(ctx, sorted_reshaped_mask, reverse_argsort);
+    ggml_set_name(reshaped_mask, "top_p_reshaped_mask");
+
+    struct ggml_tensor * mask = ggml_reshape_2d(ctx, reshaped_mask, sorted_mask->ne[0], 1);
+    ggml_set_name(mask, "top_p_mask");
+
+    // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
+    // top_p_bias = (mask * 1e9f) - 1e9f.
+    // So entries in the mask that we want to discard become -1e9f, and the
+    // others become 0 (meaning they do not affect the logits).
+    const float large_val = 1e9f;
+    struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
+    ggml_set_name(top_p_bias, "top_p_bias");
+
+    ggml_data->logits = ggml_add(ctx, ggml_data->logits, top_p_bias);
+    ggml_set_name(ggml_data->logits, "top_p_logits");
+
+    ggml_build_forward_expand(gf, ggml_data->logits);
+}
+
+static const char * llama_sampler_backend_top_p_name(const struct llama_sampler *) {
+    return "backend-top-p";
+}
+
+static void llama_sampler_backend_top_p_free(struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_backend_top_p_ctx *) smpl->ctx;
+    delete sctx;
+}
+
+static struct llama_sampler * llama_sampler_backend_top_p_clone(const struct llama_sampler * smpl) {
+    auto * sctx = (llama_sampler_backend_top_p_ctx *) smpl->ctx;
+    return llama_sampler_backend_init_top_p(sctx->p);
+}
+
+struct llama_sampler * llama_sampler_backend_init_top_p(float p) {
+    static const llama_sampler_i iface = {
+        /*.name           =*/ llama_sampler_backend_top_p_name,
+        /*.accept         =*/ nullptr,
+        /*.apply          =*/ nullptr,
+        /*.reset          =*/ nullptr,
+        /*.clone          =*/ llama_sampler_backend_top_p_clone,
+        /*.free           =*/ llama_sampler_backend_top_p_free,
+        /*.apply_ggml     =*/ llama_sampler_backend_top_p_apply_ggml,
+        /*.accept_ggml    =*/ nullptr,
+        /*.set_input_ggml =*/ nullptr,
+        /*.init_ggml      =*/ llama_sampler_backend_top_p_init_ggml,
+    };
+
+    auto * sctx = new llama_sampler_backend_top_p_ctx {
+        /*.p      =*/ p,
+        /*.device =*/ nullptr,
+    };
+
+    auto * sampler = new llama_sampler {
+        /*.iface =*/ &iface,
+        /*.ctx   =*/ sctx,
+    };
+
+    return sampler;
+}
\ No newline at end of file
diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp
index cd9aa003b5..47d2a139ea 100644
--- a/tests/test-backend-sampler.cpp
+++ b/tests/test-backend-sampler.cpp
@@ -481,6 +481,61 @@ static void test_backend_min_p_sampling(const char * model_path) {
     llama_sampler_free(chain);
 }
 
+static void test_backend_top_p_sampling(const char * model_path) {
+    test_model_context test_ctx;
+
+    const int seq_id = 0;
+    const float p = 0.9f;
+    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
+    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
+    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_top_p(p));
+    std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
+
+    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
+        return;
+    }
+
+    if (!test_ctx.decode({{seq_id, "Hello"}})) {
+        return;
+    }
+
+    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
+
+    float * logits = llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx);
+    uint32_t n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+
+    // Collect the logits that were not masked out by the top-p filter
+    std::vector<float> filtered_logits;
+    for (size_t i = 0; i < n_logits; ++i) {
+        if (logits[i] > -1e9f) {
+            filtered_logits.push_back(logits[i]);
+        }
+    }
+    GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab);
+
+    // Sample with a CPU sampler to verify that the filtered logits are reasonable
+    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
+    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
+    llama_sampler_chain_add(chain, llama_sampler_init_dist(88));
+
+    llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+    const std::string token_str = test_ctx.token_to_piece(token, false);
+    printf("top-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
+    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+
+    // Decode and sample 10 more tokens
+    for (int i = 0; i < 10; i++) {
+        int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
+        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, loop_idx);
+        printf("top-p gen step %d: token id: %5d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
+        test_ctx.decode_token(token, 0);
+    }
+
+    printf("top-p sampling test PASSED\n");
+
+    llama_sampler_free(chain);
+}
+
 static void test_backend_multi_sequence_sampling(const char * model_path) {
     test_model_context test_ctx;
 
@@ -934,6 +989,7 @@ static const backend_test_case BACKEND_TESTS[] = {
     { "mixed",     test_backend_mixed_sampling,     true },
     { "min_p",     test_backend_min_p_sampling,     true },
     { "cpu_mixed", test_backend_cpu_mixed_batch,    true },
+    { "top_p",     test_backend_top_p_sampling,     true },
 };
 
 struct backend_cli_args {