sampling : optimize logit_bias sampler

This commit is contained in:
Georgi Gerganov 2025-12-11 11:14:39 +02:00
parent 56720f8f01
commit 54e9054017
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
1 changed files with 30 additions and 11 deletions

View File

@ -3223,6 +3223,7 @@ struct llama_sampler_logit_bias : public llama_sampler_backend {
std::vector<llama_logit_bias> to_search;
struct ggml_tensor * inp_logit_bias;
struct ggml_tensor * inp_logit_idxs;
ggml_context_ptr inp_ctx;
ggml_backend_buffer_ptr inp_buf;
@ -3288,9 +3289,14 @@ static void llama_sampler_logit_bias_backend_apply(
return;
}
// Add the sparse logit logit_bias to the logits
struct ggml_tensor * logit_biased = ggml_add(ctx, data->logits, sctx->inp_logit_bias);
data->logits = logit_biased;
//struct ggml_tensor * logit_biased = ggml_add(ctx, data->logits, sctx->inp_logit_bias);
ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);
cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));
cur = ggml_set_rows(ctx, cur, sctx->inp_logit_bias, sctx->inp_logit_idxs);
cur = ggml_reshape_1d(ctx, cur, ggml_nelements(cur));
data->logits = ggml_add(ctx, data->logits, cur);
}
static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) {
@ -3298,16 +3304,23 @@ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * sm
if (sctx->logit_bias.empty()) {
return;
}
GGML_ASSERT(sctx->inp_logit_bias != nullptr);
// Create a sparse logit_bias vector from the logit_bias entries.
std::vector<float> logit_bias_sparse(sctx->n_vocab, 0.0f);
for (const auto & lb : sctx->logit_bias) {
GGML_ASSERT(sctx->inp_logit_bias != nullptr);
GGML_ASSERT(sctx->inp_logit_idxs != nullptr);
const size_t n = sctx->logit_bias.size();
std::vector<float> data_logit_bias(n, 0.0f);
std::vector<int32_t> data_logit_idxs(n, 0);
for (size_t i = 0; i < n; ++i) {
const auto & lb = sctx->logit_bias[i];
GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
logit_bias_sparse[lb.token] = lb.bias;
data_logit_bias[i] = lb.bias;
data_logit_idxs[i] = lb.token;
}
ggml_backend_tensor_set(sctx->inp_logit_bias, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
ggml_backend_tensor_set(sctx->inp_logit_bias, data_logit_bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
ggml_backend_tensor_set(sctx->inp_logit_idxs, data_logit_idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs));
}
static bool llama_sampler_logit_bias_backend_init(
@ -3322,17 +3335,23 @@ static bool llama_sampler_logit_bias_backend_init(
}
ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_size =*/ 2*ggml_tensor_overhead(),
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
sctx->inp_ctx.reset(ggml_init(params));
sctx->inp_logit_bias = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, sctx->n_vocab);
const size_t n = sctx->logit_bias.size();
sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n);
ggml_set_name(sctx->inp_logit_bias, "logit_bias");
ggml_set_input(sctx->inp_logit_bias);
sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n);
ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
ggml_set_input(sctx->inp_logit_idxs);
// Allocate all tensors from our context to the backend
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));