sampling : optimize logit_bias sampler
This commit is contained in:
parent
56720f8f01
commit
54e9054017
|
|
@ -3223,6 +3223,7 @@ struct llama_sampler_logit_bias : public llama_sampler_backend {
|
|||
std::vector<llama_logit_bias> to_search;
|
||||
|
||||
struct ggml_tensor * inp_logit_bias;
|
||||
struct ggml_tensor * inp_logit_idxs;
|
||||
|
||||
ggml_context_ptr inp_ctx;
|
||||
ggml_backend_buffer_ptr inp_buf;
|
||||
|
|
@ -3288,9 +3289,14 @@ static void llama_sampler_logit_bias_backend_apply(
|
|||
return;
|
||||
}
|
||||
|
||||
// Add the sparse logit logit_bias to the logits
|
||||
struct ggml_tensor * logit_biased = ggml_add(ctx, data->logits, sctx->inp_logit_bias);
|
||||
data->logits = logit_biased;
|
||||
//struct ggml_tensor * logit_biased = ggml_add(ctx, data->logits, sctx->inp_logit_bias);
|
||||
ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);
|
||||
|
||||
cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));
|
||||
cur = ggml_set_rows(ctx, cur, sctx->inp_logit_bias, sctx->inp_logit_idxs);
|
||||
cur = ggml_reshape_1d(ctx, cur, ggml_nelements(cur));
|
||||
|
||||
data->logits = ggml_add(ctx, data->logits, cur);
|
||||
}
|
||||
|
||||
static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) {
|
||||
|
|
@ -3298,16 +3304,23 @@ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * sm
|
|||
if (sctx->logit_bias.empty()) {
|
||||
return;
|
||||
}
|
||||
GGML_ASSERT(sctx->inp_logit_bias != nullptr);
|
||||
|
||||
// Create a sparse logit_bias vector from the logit_bias entries.
|
||||
std::vector<float> logit_bias_sparse(sctx->n_vocab, 0.0f);
|
||||
for (const auto & lb : sctx->logit_bias) {
|
||||
GGML_ASSERT(sctx->inp_logit_bias != nullptr);
|
||||
GGML_ASSERT(sctx->inp_logit_idxs != nullptr);
|
||||
|
||||
const size_t n = sctx->logit_bias.size();
|
||||
|
||||
std::vector<float> data_logit_bias(n, 0.0f);
|
||||
std::vector<int32_t> data_logit_idxs(n, 0);
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
const auto & lb = sctx->logit_bias[i];
|
||||
GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
|
||||
logit_bias_sparse[lb.token] = lb.bias;
|
||||
data_logit_bias[i] = lb.bias;
|
||||
data_logit_idxs[i] = lb.token;
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(sctx->inp_logit_bias, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
|
||||
ggml_backend_tensor_set(sctx->inp_logit_bias, data_logit_bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
|
||||
ggml_backend_tensor_set(sctx->inp_logit_idxs, data_logit_idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs));
|
||||
}
|
||||
|
||||
static bool llama_sampler_logit_bias_backend_init(
|
||||
|
|
@ -3322,17 +3335,23 @@ static bool llama_sampler_logit_bias_backend_init(
|
|||
}
|
||||
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ ggml_tensor_overhead(),
|
||||
/*.mem_size =*/ 2*ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
sctx->inp_ctx.reset(ggml_init(params));
|
||||
|
||||
sctx->inp_logit_bias = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, sctx->n_vocab);
|
||||
const size_t n = sctx->logit_bias.size();
|
||||
|
||||
sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n);
|
||||
ggml_set_name(sctx->inp_logit_bias, "logit_bias");
|
||||
ggml_set_input(sctx->inp_logit_bias);
|
||||
|
||||
sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n);
|
||||
ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
|
||||
ggml_set_input(sctx->inp_logit_idxs);
|
||||
|
||||
// Allocate all tensors from our context to the backend
|
||||
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue