sampling : optimize logit_bias sampler
This commit is contained in:
parent
56720f8f01
commit
54e9054017
|
|
@ -3223,6 +3223,7 @@ struct llama_sampler_logit_bias : public llama_sampler_backend {
|
||||||
std::vector<llama_logit_bias> to_search;
|
std::vector<llama_logit_bias> to_search;
|
||||||
|
|
||||||
struct ggml_tensor * inp_logit_bias;
|
struct ggml_tensor * inp_logit_bias;
|
||||||
|
struct ggml_tensor * inp_logit_idxs;
|
||||||
|
|
||||||
ggml_context_ptr inp_ctx;
|
ggml_context_ptr inp_ctx;
|
||||||
ggml_backend_buffer_ptr inp_buf;
|
ggml_backend_buffer_ptr inp_buf;
|
||||||
|
|
@ -3288,9 +3289,14 @@ static void llama_sampler_logit_bias_backend_apply(
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the sparse logit_bias to the logits
|
//struct ggml_tensor * logit_biased = ggml_add(ctx, data->logits, sctx->inp_logit_bias);
|
||||||
struct ggml_tensor * logit_biased = ggml_add(ctx, data->logits, sctx->inp_logit_bias);
|
ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f);
|
||||||
data->logits = logit_biased;
|
|
||||||
|
cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur));
|
||||||
|
cur = ggml_set_rows(ctx, cur, sctx->inp_logit_bias, sctx->inp_logit_idxs);
|
||||||
|
cur = ggml_reshape_1d(ctx, cur, ggml_nelements(cur));
|
||||||
|
|
||||||
|
data->logits = ggml_add(ctx, data->logits, cur);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) {
|
static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) {
|
||||||
|
|
@ -3298,16 +3304,23 @@ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * sm
|
||||||
if (sctx->logit_bias.empty()) {
|
if (sctx->logit_bias.empty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
GGML_ASSERT(sctx->inp_logit_bias != nullptr);
|
|
||||||
|
|
||||||
// Create a sparse logit_bias vector from the logit_bias entries.
|
GGML_ASSERT(sctx->inp_logit_bias != nullptr);
|
||||||
std::vector<float> logit_bias_sparse(sctx->n_vocab, 0.0f);
|
GGML_ASSERT(sctx->inp_logit_idxs != nullptr);
|
||||||
for (const auto & lb : sctx->logit_bias) {
|
|
||||||
|
const size_t n = sctx->logit_bias.size();
|
||||||
|
|
||||||
|
std::vector<float> data_logit_bias(n, 0.0f);
|
||||||
|
std::vector<int32_t> data_logit_idxs(n, 0);
|
||||||
|
for (size_t i = 0; i < n; ++i) {
|
||||||
|
const auto & lb = sctx->logit_bias[i];
|
||||||
GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
|
GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
|
||||||
logit_bias_sparse[lb.token] = lb.bias;
|
data_logit_bias[i] = lb.bias;
|
||||||
|
data_logit_idxs[i] = lb.token;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_backend_tensor_set(sctx->inp_logit_bias, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
|
ggml_backend_tensor_set(sctx->inp_logit_bias, data_logit_bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
|
||||||
|
ggml_backend_tensor_set(sctx->inp_logit_idxs, data_logit_idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs));
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool llama_sampler_logit_bias_backend_init(
|
static bool llama_sampler_logit_bias_backend_init(
|
||||||
|
|
@ -3322,17 +3335,23 @@ static bool llama_sampler_logit_bias_backend_init(
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_init_params params = {
|
ggml_init_params params = {
|
||||||
/*.mem_size =*/ ggml_tensor_overhead(),
|
/*.mem_size =*/ 2*ggml_tensor_overhead(),
|
||||||
/*.mem_buffer =*/ nullptr,
|
/*.mem_buffer =*/ nullptr,
|
||||||
/*.no_alloc =*/ true,
|
/*.no_alloc =*/ true,
|
||||||
};
|
};
|
||||||
|
|
||||||
sctx->inp_ctx.reset(ggml_init(params));
|
sctx->inp_ctx.reset(ggml_init(params));
|
||||||
|
|
||||||
sctx->inp_logit_bias = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, sctx->n_vocab);
|
const size_t n = sctx->logit_bias.size();
|
||||||
|
|
||||||
|
sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n);
|
||||||
ggml_set_name(sctx->inp_logit_bias, "logit_bias");
|
ggml_set_name(sctx->inp_logit_bias, "logit_bias");
|
||||||
ggml_set_input(sctx->inp_logit_bias);
|
ggml_set_input(sctx->inp_logit_bias);
|
||||||
|
|
||||||
|
sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n);
|
||||||
|
ggml_set_name(sctx->inp_logit_idxs, "logit_idxs");
|
||||||
|
ggml_set_input(sctx->inp_logit_idxs);
|
||||||
|
|
||||||
// Allocate all tensors from our context to the backend
|
// Allocate all tensors from our context to the backend
|
||||||
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
|
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue