Add initial version for top-p sampling
As we only support static graphs for the time being and do not know the size of the top-p output in advance, we have to apply value scaling, the same as for the min-p operator. Further improvements can be made to the unit test (i.e. check that top_p on the backend is equivalent to top_p on the CPU) and to the implementation, by constructing candidates and sorting those instead of reversing the sort of the logits (this would be arange + get_rows instead of argsort + get_rows).
This commit is contained in:
parent
117e2079a9
commit
333da805fe
|
|
@ -169,6 +169,7 @@ static bool common_sampler_type_has_backend_support(enum common_sampler_type typ
|
|||
case COMMON_SAMPLER_TYPE_TOP_K:
|
||||
case COMMON_SAMPLER_TYPE_TEMPERATURE:
|
||||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
|
@ -382,6 +383,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
|
|||
case COMMON_SAMPLER_TYPE_MIN_P:
|
||||
llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_min_p(params.min_p));
|
||||
break;
|
||||
case COMMON_SAMPLER_TYPE_TOP_P:
|
||||
llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_top_p(params.top_p));
|
||||
break;
|
||||
default:
|
||||
GGML_ASSERT(false && "unsupported backend sampler");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1403,6 +1403,9 @@ extern "C" {
|
|||
/// @details Min-P filtering on backend - filter tokens with a probability less than p times the maximum probability.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_backend_init_min_p(float p);
|
||||
|
||||
/// @details Top-p filtering on backend - keep the highest-probability tokens up to a cumulative (softmax) probability of p and filter out the remaining tokens.
|
||||
LLAMA_API struct llama_sampler * llama_sampler_backend_init_top_p(float p);
|
||||
|
||||
// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
|
||||
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
|
||||
|
||||
|
|
|
|||
|
|
@ -596,3 +596,125 @@ struct llama_sampler * llama_sampler_backend_init_min_p(float p) {
|
|||
|
||||
return sampler;
|
||||
}
|
||||
|
||||
|
||||
struct llama_sampler_backend_top_p_ctx {
|
||||
float p;
|
||||
|
||||
// Only required for checking operation support and can be removed later.
|
||||
ggml_backend_dev_t device;
|
||||
};
|
||||
|
||||
|
||||
static void llama_sampler_backend_top_p_init_ggml(
|
||||
struct llama_sampler * smpl,
|
||||
ggml_backend_buffer_type_t buft) {
|
||||
auto * sctx = (llama_sampler_backend_top_p_ctx *) smpl->ctx;
|
||||
sctx->device = ggml_backend_buft_get_device(buft);
|
||||
}
|
||||
|
||||
static void llama_sampler_backend_top_p_apply_ggml(
|
||||
struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct llama_sampler_ggml_data * ggml_data) {
|
||||
GGML_UNUSED(gf);
|
||||
|
||||
auto * sctx = (llama_sampler_backend_top_p_ctx *) smpl->ctx;
|
||||
|
||||
struct ggml_tensor * softmax = ggml_soft_max(ctx, ggml_data->logits);
|
||||
ggml_set_name(softmax, "top_p_softmax");
|
||||
|
||||
// Get the sorted indices of the softmax probabilities in descending order.
|
||||
struct ggml_tensor * sorted_idx = ggml_argsort(ctx, softmax, GGML_SORT_ORDER_DESC);
|
||||
ggml_set_name(sorted_idx, "top_p_sorted_idx");
|
||||
|
||||
// Do the sorting via reshape + get_rows
|
||||
struct ggml_tensor * softmax_reshaped = ggml_reshape_2d(ctx, softmax, 1, softmax->ne[0]);
|
||||
ggml_set_name(softmax_reshaped, "top_p_softmax_reshaped");
|
||||
|
||||
struct ggml_tensor * sorted_probs = ggml_get_rows(ctx, softmax_reshaped, sorted_idx);
|
||||
ggml_set_name(sorted_probs, "top_p_sorted_probs");
|
||||
|
||||
struct ggml_tensor * sorted_probs_reshaped = ggml_reshape_2d(ctx, sorted_probs, softmax->ne[0], 1);
|
||||
ggml_set_name(sorted_probs_reshaped, "top_p_sorted_probs_reshaped");
|
||||
// Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM.
|
||||
struct ggml_tensor * sorted_cdf = ggml_cumsum(ctx, sorted_probs_reshaped);
|
||||
ggml_set_name(sorted_cdf, "top_p_sorted_cdf");
|
||||
|
||||
// Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep
|
||||
struct ggml_tensor * sorted_cdf_scaled = ggml_scale_bias(ctx, sorted_cdf, -1.0f, sctx->p);
|
||||
ggml_set_name(sorted_cdf_scaled, "top_p_sorted_cdf_scaled");
|
||||
|
||||
struct ggml_tensor * sorted_mask = ggml_step(ctx, sorted_cdf_scaled);
|
||||
ggml_set_name(sorted_mask, "top_p_sorted_mask");
|
||||
|
||||
// reverse sorting by argsort(argsort)
|
||||
// cast to F32 since cuda only supports float inputs
|
||||
struct ggml_tensor * reverse_argsort = ggml_argsort(ctx, ggml_cast(ctx, sorted_idx, GGML_TYPE_F32), GGML_SORT_ORDER_ASC);
|
||||
ggml_set_name(reverse_argsort, "top_p_reverse_argsort");
|
||||
|
||||
// Do the sorting via reshape + get_rows
|
||||
struct ggml_tensor * sorted_reshaped_mask = ggml_reshape_2d(ctx, sorted_mask, 1, sorted_mask->ne[0]);
|
||||
ggml_set_name(sorted_reshaped_mask, "top_p_sorted_reshaped_mask");
|
||||
|
||||
struct ggml_tensor * reshaped_mask = ggml_get_rows(ctx, sorted_reshaped_mask, reverse_argsort);
|
||||
ggml_set_name(reshaped_mask, "top_p_reshaped_mask");
|
||||
|
||||
struct ggml_tensor * mask = ggml_reshape_2d(ctx, reshaped_mask, sorted_mask->ne[0], 1);
|
||||
ggml_set_name(mask, "top_p_mask");
|
||||
|
||||
// Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
|
||||
// top_p_bias = (mask * 1e9f) - 1e9f.
|
||||
// So entries in the mask that we want to discard will become -1e9f, and
|
||||
// others will be 0 (meaning that will not effect the logits).
|
||||
const float large_val = 1e9f;
|
||||
struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
|
||||
ggml_set_name(top_p_bias, "top_p_bias");
|
||||
|
||||
ggml_data->logits = ggml_add(ctx, ggml_data->logits, top_p_bias);
|
||||
ggml_set_name(ggml_data->logits, "top_p_logits");
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_data->logits);
|
||||
}
|
||||
|
||||
// name callback: returns the identifier of the backend top-p sampler.
// The sampler argument is not needed and therefore ignored.
static const char * llama_sampler_backend_top_p_name(const struct llama_sampler * /*smpl*/) {
    static const char * const sampler_name = "backend-top-p";
    return sampler_name;
}
|
||||
|
||||
static void llama_sampler_backend_top_p_free(struct llama_sampler * smpl) {
|
||||
auto * sctx = (llama_sampler_backend_top_p_ctx *) smpl->ctx;
|
||||
delete sctx;
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_backend_top_p_clone(const struct llama_sampler * smpl) {
|
||||
auto * sctx = (llama_sampler_backend_top_p_ctx *) smpl->ctx;
|
||||
return llama_sampler_backend_init_top_p(sctx->p);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_backend_init_top_p(float p) {
|
||||
static const llama_sampler_i iface = {
|
||||
/*.name =*/ llama_sampler_backend_top_p_name,
|
||||
/*.accept =*/ nullptr,
|
||||
/*.apply =*/ nullptr,
|
||||
/*.reset =*/ nullptr,
|
||||
/*.clone =*/ llama_sampler_backend_top_p_clone,
|
||||
/*.free =*/ llama_sampler_backend_top_p_free,
|
||||
/*.apply_ggml =*/ llama_sampler_backend_top_p_apply_ggml,
|
||||
/*.accept_ggml =*/ nullptr,
|
||||
/*.set_input_ggml =*/ nullptr,
|
||||
/*.init_ggml =*/ llama_sampler_backend_top_p_init_ggml,
|
||||
};
|
||||
|
||||
auto * sctx = new llama_sampler_backend_top_p_ctx {
|
||||
/*.p =*/ p,
|
||||
/*.device =*/ nullptr,
|
||||
};
|
||||
|
||||
auto * sampler = new llama_sampler {
|
||||
/*.iface =*/ &iface,
|
||||
/*.ctx =*/ sctx,
|
||||
};
|
||||
|
||||
return sampler;
|
||||
}
|
||||
|
|
@ -481,6 +481,61 @@ static void test_backend_min_p_sampling(const char * model_path) {
|
|||
llama_sampler_free(chain);
|
||||
}
|
||||
|
||||
static void test_backend_top_p_sampling(const char * model_path) {
|
||||
test_model_context test_ctx;
|
||||
|
||||
const int seq_id = 0;
|
||||
const float p = 0.9;
|
||||
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
|
||||
struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
|
||||
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_top_p(p));
|
||||
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
|
||||
|
||||
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!test_ctx.decode({{seq_id, "Hello"}})) {
|
||||
return;
|
||||
}
|
||||
|
||||
int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
|
||||
|
||||
float * logits = llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx);
|
||||
uint32_t n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
|
||||
|
||||
// Print the logits that are above the min-p threshold
|
||||
std::vector<float> filtered_logits;
|
||||
for (size_t i = 0; i < n_logits; ++i) {
|
||||
if (logits[i] > -1e9f) {
|
||||
filtered_logits.push_back(logits[i]);
|
||||
}
|
||||
}
|
||||
GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab);
|
||||
|
||||
// Sample using CPU sampler for verification to inspect they are reasonable
|
||||
struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
|
||||
struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_dist(88));
|
||||
|
||||
llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
|
||||
const std::string token_str = test_ctx.token_to_piece(token, false);
|
||||
printf("top-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str());
|
||||
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
|
||||
|
||||
// Decode and sampler 10 more tokens
|
||||
for (int i = 0; i < 10; i++) {
|
||||
int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
|
||||
llama_token token = llama_sampler_sample(chain, test_ctx.ctx, loop_idx);
|
||||
printf("top-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
|
||||
test_ctx.decode_token(token, 0);
|
||||
}
|
||||
|
||||
printf("top-p sampling test PASSED\n");
|
||||
|
||||
llama_sampler_free(chain);
|
||||
}
|
||||
|
||||
static void test_backend_multi_sequence_sampling(const char * model_path) {
|
||||
test_model_context test_ctx;
|
||||
|
||||
|
|
@ -934,6 +989,7 @@ static const backend_test_case BACKEND_TESTS[] = {
|
|||
{ "mixed", test_backend_mixed_sampling, true },
|
||||
{ "min_p", test_backend_min_p_sampling, true },
|
||||
{ "cpu_mixed", test_backend_cpu_mixed_batch, true },
|
||||
{ "top_p", test_backend_top_p_sampling, true },
|
||||
};
|
||||
|
||||
struct backend_cli_args {
|
||||
|
|
|
|||
Loading…
Reference in New Issue