diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 2c1127666f..32e9fa5ed1 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -3223,6 +3223,7 @@ struct llama_sampler_logit_bias : public llama_sampler_backend { std::vector to_search; struct ggml_tensor * inp_logit_bias; + struct ggml_tensor * inp_logit_idxs; ggml_context_ptr inp_ctx; ggml_backend_buffer_ptr inp_buf; @@ -3288,9 +3289,14 @@ static void llama_sampler_logit_bias_backend_apply( return; } - // Add the sparse logit logit_bias to the logits - struct ggml_tensor * logit_biased = ggml_add(ctx, data->logits, sctx->inp_logit_bias); - data->logits = logit_biased; + //struct ggml_tensor * logit_biased = ggml_add(ctx, data->logits, sctx->inp_logit_bias); + ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f); + + cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur)); + cur = ggml_set_rows(ctx, cur, sctx->inp_logit_bias, sctx->inp_logit_idxs); + cur = ggml_reshape_1d(ctx, cur, ggml_nelements(cur)); + + data->logits = ggml_add(ctx, data->logits, cur); } static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) { @@ -3298,16 +3304,23 @@ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * sm if (sctx->logit_bias.empty()) { return; } - GGML_ASSERT(sctx->inp_logit_bias != nullptr); - // Create a sparse logit_bias vector from the logit_bias entries. - std::vector logit_bias_sparse(sctx->n_vocab, 0.0f); - for (const auto & lb : sctx->logit_bias) { + GGML_ASSERT(sctx->inp_logit_bias != nullptr); + GGML_ASSERT(sctx->inp_logit_idxs != nullptr); + + const size_t n = sctx->logit_bias.size(); + + std::vector data_logit_bias(n, 0.0f); + std::vector data_logit_idxs(n, 0); + for (size_t i = 0; i < n; ++i) { + const auto & lb = sctx->logit_bias[i]; GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab); - logit_bias_sparse[lb.token] = lb.bias; + data_logit_bias[i] = lb.bias; + data_logit_idxs[i] = lb.token; } - ggml_backend_tensor_set(sctx->inp_logit_bias, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->inp_logit_bias)); + ggml_backend_tensor_set(sctx->inp_logit_bias, data_logit_bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias)); + ggml_backend_tensor_set(sctx->inp_logit_idxs, data_logit_idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs)); } static bool llama_sampler_logit_bias_backend_init( @@ -3322,17 +3335,23 @@ static bool llama_sampler_logit_bias_backend_init( } ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_size =*/ 2*ggml_tensor_overhead(), /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true, }; sctx->inp_ctx.reset(ggml_init(params)); - sctx->inp_logit_bias = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, sctx->n_vocab); + const size_t n = sctx->logit_bias.size(); + + sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n); ggml_set_name(sctx->inp_logit_bias, "logit_bias"); ggml_set_input(sctx->inp_logit_bias); + sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n); + ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); + ggml_set_input(sctx->inp_logit_idxs); + // Allocate all tensors from our context to the backend sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));