llama : simplify

Georgi Gerganov 2025-11-29 15:58:59 +02:00
parent 2464d1b3fc
commit fbc8f49f3c
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
7 changed files with 595 additions and 967 deletions


@@ -347,19 +347,11 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
is_backend = is_backend && (params.samplers.size() == 0 || common_sampler_type_has_backend_support(params.samplers[idx_smpl]));
if (params.has_logit_bias()) {
if (is_backend) {
llama_sampler_chain_add(result->chain_backend,
llama_sampler_backend_init_logit_bias(
llama_vocab_n_tokens(vocab),
params.logit_bias.size(),
params.logit_bias.data()));
} else {
llama_sampler_chain_add(result->chain,
llama_sampler_init_logit_bias(
llama_vocab_n_tokens(vocab),
params.logit_bias.size(),
params.logit_bias.data()));
}
llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain,
llama_sampler_init_logit_bias(
llama_vocab_n_tokens(vocab),
params.logit_bias.size(),
params.logit_bias.data()));
}
if (params.mirostat == 0) {
@@ -375,16 +367,16 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
switch (cnstr) {
case COMMON_SAMPLER_TYPE_TOP_K:
llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_top_k(params.top_k));
llama_sampler_chain_add(result->chain_backend, llama_sampler_init_top_k(params.top_k));
break;
case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_temp(params.temp));
llama_sampler_chain_add(result->chain_backend, llama_sampler_init_temp(params.temp));
break;
case COMMON_SAMPLER_TYPE_MIN_P:
llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_min_p(params.min_p));
llama_sampler_chain_add(result->chain_backend, llama_sampler_init_min_p(params.min_p, params.min_keep));
break;
case COMMON_SAMPLER_TYPE_TOP_P:
llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_top_p(params.top_p));
llama_sampler_chain_add(result->chain_backend, llama_sampler_init_top_p(params.top_p, params.min_keep));
break;
default:
GGML_ASSERT(false && "unsupported backend sampler");
@@ -439,11 +431,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
}
}
if (is_backend) {
llama_sampler_chain_add(result->chain_backend, llama_sampler_backend_init_dist(params.seed));
} else {
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
}
llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, llama_sampler_init_dist(params.seed));
} else if (params.mirostat == 1) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));


@@ -74,9 +74,9 @@ int main(int argc, char ** argv) {
llama_sampler * smpl = llama_sampler_chain_init(sparams);
if (params.sampling.backend_sampling) {
llama_sampler_chain_add(smpl, llama_sampler_backend_init_top_k(params.sampling.top_k));
llama_sampler_chain_add(smpl, llama_sampler_backend_init_temp (params.sampling.temp));
llama_sampler_chain_add(smpl, llama_sampler_backend_init_dist (params.sampling.seed));
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
} else {
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k));
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep));


@@ -1388,31 +1388,9 @@ extern "C" {
// Backend samplers
//
/// @details Greedy sampling on backend - always selects the token with the highest probability
LLAMA_API struct llama_sampler * llama_sampler_backend_init_greedy(void);
/// @details Temperature scaling on backend - scales logits by 1/temperature
LLAMA_API struct llama_sampler * llama_sampler_backend_init_temp(float temp);
/// @details Top-K filtering on backend - keeps only the k tokens with highest probabilities
LLAMA_API struct llama_sampler * llama_sampler_backend_init_top_k(int32_t k);
/// @details Distribution sampling on backend - final sampling step that selects a token
LLAMA_API struct llama_sampler * llama_sampler_backend_init_dist(uint32_t seed);
/// @details Min-P filtering on backend - filter tokens with a probability less than p times the maximum probability.
LLAMA_API struct llama_sampler * llama_sampler_backend_init_min_p(float p);
/// @details Top-p filtering on backend - filter all tokens with cumulative pseudo-probability less than p.
LLAMA_API struct llama_sampler * llama_sampler_backend_init_top_p(float p);
// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
LLAMA_API struct llama_sampler * llama_sampler_backend_init_logit_bias(int32_t n_vocab,
int32_t n_logit_bias,
const llama_logit_bias * logit_bias);
/// @details Sample and accept a token from the idx-th output of the last evaluation
//
// Shorthand for:


@@ -31,7 +31,6 @@ add_library(llama
llama-model.cpp
llama-quant.cpp
llama-sampling.cpp
llama-backend-sampler.cpp
llama-vocab.cpp
unicode-data.cpp
unicode.cpp


@@ -1,711 +0,0 @@
#include "llama.h"
#include "ggml.h"
#include <cstdio>
#include <chrono>
#include <random>
#include <unordered_map>
#include <vector>
static void llama_sampler_backend_greedy_apply_ggml(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) {
GGML_UNUSED(gf);
GGML_UNUSED(smpl);
struct ggml_tensor * argmax_result = ggml_argmax(ctx, ggml_data->logits);
ggml_set_name(argmax_result, "argmax_result");
ggml_data->sampled = argmax_result;
}
static const char * llama_sampler_backend_greedy_sampler_name(const struct llama_sampler *) {
return "test-ggml";
}
static struct llama_sampler * llama_sampler_backend_greedy_clone(const struct llama_sampler * smpl) {
(void) smpl;
return llama_sampler_backend_init_greedy();
}
struct llama_sampler * llama_sampler_backend_init_greedy() {
static const llama_sampler_i iface = {
/*.name =*/ llama_sampler_backend_greedy_sampler_name,
/*.accept =*/ nullptr,
/*.apply =*/ nullptr,
/*.reset =*/ nullptr,
/*.clone =*/ llama_sampler_backend_greedy_clone,
/*.free =*/ nullptr,
/*.apply_ggml =*/ llama_sampler_backend_greedy_apply_ggml,
/*.accept_ggml =*/ nullptr,
/*.set_input_ggml =*/ nullptr,
/*.init_ggml =*/ nullptr,
};
auto * sampler = new llama_sampler {
/*.iface =*/ &iface,
/*.ctx =*/ nullptr,
};
return sampler;
}
struct llama_sampler_backend_temp_ctx {
float temp;
};
static void llama_sampler_backend_temp_apply_ggml(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) {
auto * ctx_data = (llama_sampler_backend_temp_ctx *) smpl->ctx;
if (ctx_data->temp <= 0.0f) {
return;
}
struct ggml_tensor * scaled = ggml_scale(ctx, ggml_data->logits, 1.0f / ctx_data->temp);
ggml_set_name(scaled, "temp_scaled");
// Make sure the scaled tensor is contiguous for subsequent operations
ggml_data->logits = ggml_cont(ctx, scaled);
ggml_set_name(ggml_data->logits, "temp_scaled_logits");
ggml_build_forward_expand(gf, ggml_data->logits);
}
static const char * llama_sampler_backend_temp_name(const struct llama_sampler *) {
return "backend-temp";
}
static void llama_sampler_backend_temp_free(struct llama_sampler * smpl) {
auto * ctx_data = (llama_sampler_backend_temp_ctx *) smpl->ctx;
delete ctx_data;
}
static struct llama_sampler * llama_sampler_backend_temp_clone(const struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_backend_temp_ctx *) smpl->ctx;
return llama_sampler_backend_init_temp(ctx->temp);
}
struct llama_sampler * llama_sampler_backend_init_temp(float temp) {
static const llama_sampler_i iface = {
/*.name =*/ llama_sampler_backend_temp_name,
/*.accept =*/ nullptr,
/*.apply =*/ nullptr,
/*.reset =*/ nullptr,
/*.clone =*/ llama_sampler_backend_temp_clone,
/*.free =*/ llama_sampler_backend_temp_free,
/*.apply_ggml =*/ llama_sampler_backend_temp_apply_ggml,
/*.accept_ggml =*/ nullptr,
/*.set_input_ggml =*/ nullptr,
/*.init_ggml =*/ nullptr,
};
auto * ctx_data = new llama_sampler_backend_temp_ctx {
/*.temp =*/ temp,
};
auto * sampler = new llama_sampler {
/*.iface =*/ &iface,
/*.ctx =*/ ctx_data,
};
return sampler;
}
struct llama_sampler_backend_top_k_ctx {
int32_t k;
// Only required for checking operation support and can be removed later.
ggml_backend_dev_t device;
};
static void llama_sampler_backend_top_k_init_ggml(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * ctx_data = (llama_sampler_backend_top_k_ctx *) smpl->ctx;
ctx_data->device = ggml_backend_buft_get_device(buft);
}
static void llama_sampler_backend_top_k_apply_ggml(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) {
auto * ctx_data = (llama_sampler_backend_top_k_ctx *) smpl->ctx;
struct ggml_tensor * top_k = ggml_top_k(ctx, ggml_data->logits, ctx_data->k);
ggml_set_name(top_k, "top_k");
// top_k is a view of argsort - check if backend supports the underlying argsort operation
// by checking the source tensor (which is the argsort result)
if (ctx_data->device && top_k->src[0] && !ggml_backend_dev_supports_op(ctx_data->device, top_k->src[0])) {
fprintf(stderr, "Warning: backend does not support argsort operation required for top-k sampling\n");
fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
}
ggml_data->candidates = top_k;
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]);
struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
ggml_set_name(top_k_rows, "top_k_rows");
ggml_data->logits = ggml_reshape_1d(ctx, top_k_rows, ctx_data->k);
ggml_build_forward_expand(gf, ggml_data->logits);
}
static const char * llama_sampler_backend_top_k_name(const struct llama_sampler *) {
return "backend-top-k";
}
static void llama_sampler_backend_top_k_free(struct llama_sampler * smpl) {
auto * ctx_data = (llama_sampler_backend_top_k_ctx *) smpl->ctx;
delete ctx_data;
}
static struct llama_sampler * llama_sampler_backend_top_k_clone(const struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_backend_top_k_ctx *) smpl->ctx;
return llama_sampler_backend_init_top_k(ctx->k);
}
struct llama_sampler * llama_sampler_backend_init_top_k(int32_t k) {
static const llama_sampler_i iface = {
/*.name =*/ llama_sampler_backend_top_k_name,
/*.accept =*/ nullptr,
/*.apply =*/ nullptr,
/*.reset =*/ nullptr,
/*.clone =*/ llama_sampler_backend_top_k_clone,
/*.free =*/ llama_sampler_backend_top_k_free,
/*.apply_ggml =*/ llama_sampler_backend_top_k_apply_ggml,
/*.accept_ggml =*/ nullptr,
/*.set_input_ggml =*/ nullptr,
/*.init_ggml =*/ llama_sampler_backend_top_k_init_ggml,
};
auto * ctx_data = new llama_sampler_backend_top_k_ctx {
/*.k =*/ k,
/*.device =*/ nullptr,
};
auto * sampler = new llama_sampler {
/*.iface =*/ &iface,
/*.ctx =*/ ctx_data,
};
return sampler;
}
static uint32_t get_rng_seed(uint32_t seed) {
if (seed == LLAMA_DEFAULT_SEED) {
// use system clock if std::random_device is not a true RNG
static bool is_rd_prng = std::random_device().entropy() == 0;
if (is_rd_prng) {
return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
}
std::random_device rd;
return rd();
}
return seed;
}
struct llama_sampler_backend_dist_ctx {
const uint32_t seed;
uint32_t seed_cur;
std::mt19937 rng;
struct ggml_tensor * uniform;
struct ggml_context * ctx;
ggml_backend_buffer_t buffer;
// Only required for checking operation support and can be removed later.
ggml_backend_dev_t device;
};
static void llama_sampler_backend_dist_init_ggml(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx;
sctx->device = ggml_backend_buft_get_device(buft);
ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
sctx->ctx = ggml_init(params);
// Create the uniform random scalar input tensor. This will be set by
// llama_sampler_backend_dist_set_input_ggml after this graph is built.
sctx->uniform = ggml_new_tensor_1d(sctx->ctx, GGML_TYPE_F32, 1);
ggml_set_name(sctx->uniform, "uniform");
ggml_set_input(sctx->uniform);
ggml_set_output(sctx->uniform);
// Allocate all tensors from our context to the backend
sctx->buffer = ggml_backend_alloc_ctx_tensors_from_buft(sctx->ctx, buft);
}
static void llama_sampler_backend_dist_set_input_ggml(struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx;
GGML_ASSERT(sctx->uniform != nullptr);
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
const float rnd = dist(sctx->rng);
ggml_backend_tensor_set(sctx->uniform, &rnd, 0, sizeof(float));
}
static void llama_sampler_backend_dist_apply_ggml(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) {
GGML_UNUSED(gf);
auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx;
struct ggml_tensor * probs = ggml_soft_max(ctx, ggml_data->logits);
ggml_set_name(probs, "dist_probs");
struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
if (sctx->device && !ggml_backend_dev_supports_op(sctx->device, cumsum)) {
fprintf(stderr, "Warning: backend does not support cumsum operation required for dist sampling\n");
fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
}
ggml_set_name(cumsum, "cumsum");
// The uniform tensor holds a random value which we subtract from the
// cumsum tensor (the uniform tensor is broadcast by ggml_sub).
// Recall that each entry in cumsum is the cumulative probability up to that
// index, so values stay negative while the cumulative total is below the
// random value, and become zero/positive once the threshold is crossed.
struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->uniform);
ggml_set_name(diff, "dist_cumsum");
// The ggml_step function produces a tensor where entries are 1 if the
// corresponding entry in diff is > 0, and 0 otherwise. So all values before
// the index where the cumulative probability exceeds the random value are 0,
// and all entries from that index onward are 1.
struct ggml_tensor * mask = ggml_step(ctx, diff);
ggml_set_name(mask, "dist_mask");
// Taking the sum of the mask gives us the number of elements from the
// threshold index onward.
struct ggml_tensor * idxf = ggml_sum(ctx, mask);
ggml_set_name(idxf, "dist_index_f32");
// Use ggml_scale_bias to scale the index value by -1 and then add the size
// of the mask to that value so we get the correct index ((-1 * idxf) + n).
struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
ggml_set_name(idx, "dist_index_i32");
// Map back to original vocab ids if a candidates tensor is available.
struct ggml_tensor * sampled_token = idx;
if (ggml_data->candidates != nullptr) {
struct ggml_tensor * candidates = ggml_data->candidates;
struct ggml_tensor * candidates_reshaped = ggml_view_2d(ctx, candidates, 1, ggml_nelements(candidates),
ggml_type_size(candidates->type), 0);
sampled_token = ggml_get_rows(ctx, candidates_reshaped, idx);
ggml_set_name(sampled_token, "dist_sampled_token");
}
ggml_set_output(sampled_token);
ggml_data->sampled = sampled_token;
}
static const char * llama_sampler_backend_dist_name(const struct llama_sampler *) {
return "backend-dist";
}
static void llama_sampler_backend_dist_free(struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx;
ggml_backend_buffer_free(sctx->buffer);
ggml_free(sctx->ctx);
delete sctx;
}
static struct llama_sampler * llama_sampler_backend_dist_clone(const struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx;
return llama_sampler_backend_init_dist(sctx->seed);
}
struct llama_sampler * llama_sampler_backend_init_dist(uint32_t seed) {
static const llama_sampler_i iface = {
/*.name =*/ llama_sampler_backend_dist_name,
/*.accept =*/ nullptr,
/*.apply =*/ nullptr,
/*.reset =*/ nullptr,
/*.clone =*/ llama_sampler_backend_dist_clone,
/*.free =*/ llama_sampler_backend_dist_free,
/*.apply_ggml =*/ llama_sampler_backend_dist_apply_ggml,
/*.accept_ggml =*/ nullptr,
/*.set_input_ggml =*/ llama_sampler_backend_dist_set_input_ggml,
/*.init_ggml =*/ llama_sampler_backend_dist_init_ggml,
};
auto seed_cur = get_rng_seed(seed);
auto * ctx_data = new llama_sampler_backend_dist_ctx {
/*.seed =*/ seed,
/*.seed_cur =*/ seed_cur,
/*.rng =*/ std::mt19937(seed_cur),
/*.uniform =*/ nullptr,
/*.ctx =*/ nullptr,
/*.buffer =*/ nullptr,
/*.device =*/ nullptr,
};
auto * sampler = new llama_sampler {
/*.iface =*/ &iface,
/*.ctx =*/ ctx_data,
};
return sampler;
}
struct llama_sampler_backend_logit_bias_ctx {
const int32_t n_vocab;
const std::vector<llama_logit_bias> logit_bias;
struct ggml_tensor * logit_bias_t;
struct ggml_context * ctx;
ggml_backend_buffer_t buffer;
};
static void llama_sampler_backend_logit_bias_init_ggml(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx;
if (sctx->logit_bias.empty()) {
return;
}
ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead() * sctx->n_vocab * sizeof(float),
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
sctx->ctx = ggml_init(params);
struct ggml_tensor * logit_bias = ggml_new_tensor_1d(sctx->ctx, GGML_TYPE_F32, sctx->n_vocab);
sctx->logit_bias_t = logit_bias;
ggml_set_name(sctx->logit_bias_t, "logit_bias");
ggml_set_input(sctx->logit_bias_t);
ggml_set_output(sctx->logit_bias_t);
// Allocate all tensors from our context to the backend
sctx->buffer = ggml_backend_alloc_ctx_tensors_from_buft(sctx->ctx, buft);
}
static void llama_sampler_backend_logit_bias_set_input_ggml(struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx;
if (sctx->logit_bias.empty()) {
return;
}
GGML_ASSERT(sctx->logit_bias_t != nullptr);
// Create a sparse logit_bias vector from the logit_bias entries.
std::vector<float> logit_bias_sparse(sctx->n_vocab, 0.0f);
for (const auto & lb : sctx->logit_bias) {
GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
logit_bias_sparse[lb.token] = lb.bias;
}
ggml_backend_tensor_set(sctx->logit_bias_t, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->logit_bias_t));
}
static void llama_sampler_backend_logit_bias_apply_ggml(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) {
auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx;
if (sctx->logit_bias_t == nullptr) {
return;
}
// Add the sparse logit bias to the logits
struct ggml_tensor * logit_biased = ggml_add_inplace(ctx, ggml_data->logits, sctx->logit_bias_t);
ggml_build_forward_expand(gf, logit_biased);
}
static const char * llama_sampler_backend_logit_bias_name(const struct llama_sampler *) {
return "backend-logit_bias";
}
static void llama_sampler_backend_logit_bias_free(struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx;
ggml_backend_buffer_free(sctx->buffer);
ggml_free(sctx->ctx);
delete sctx;
}
static struct llama_sampler * llama_sampler_backend_logit_bias_clone(const struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx;
return llama_sampler_backend_init_logit_bias(sctx->n_vocab,
sctx->logit_bias.size(),
sctx->logit_bias.data());
}
struct llama_sampler * llama_sampler_backend_init_logit_bias(int32_t n_vocab,
int32_t n_logit_bias,
const llama_logit_bias * logit_bias) {
static const llama_sampler_i iface = {
/*.name =*/ llama_sampler_backend_logit_bias_name,
/*.accept =*/ nullptr,
/*.apply =*/ nullptr,
/*.reset =*/ nullptr,
/*.clone =*/ llama_sampler_backend_logit_bias_clone,
/*.free =*/ llama_sampler_backend_logit_bias_free,
/*.apply_ggml =*/ llama_sampler_backend_logit_bias_apply_ggml,
/*.accept_ggml =*/ nullptr,
/*.set_input_ggml =*/ llama_sampler_backend_logit_bias_set_input_ggml,
/*.init_ggml =*/ llama_sampler_backend_logit_bias_init_ggml,
};
auto * ctx_data = new llama_sampler_backend_logit_bias_ctx {
/*.n_vocab =*/ n_vocab,
/*.logit_bias =*/ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
/*.logit_bias_t =*/ nullptr,
/*.ctx =*/ nullptr,
/*.buffer =*/ nullptr,
};
auto * sampler = new llama_sampler {
/*.iface =*/ &iface,
/*.ctx =*/ ctx_data,
};
return sampler;
}
struct llama_sampler_backend_min_p_ctx {
float p;
// Only required for checking operation support and can be removed later.
ggml_backend_dev_t device;
};
static void llama_sampler_backend_min_p_init_ggml(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_backend_min_p_ctx *) smpl->ctx;
sctx->device = ggml_backend_buft_get_device(buft);
}
static void llama_sampler_backend_min_p_apply_ggml(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) {
auto * sctx = (llama_sampler_backend_min_p_ctx *) smpl->ctx;
struct ggml_tensor * max_idx = ggml_argmax(ctx, ggml_data->logits);
ggml_set_name(max_idx, "max_idx");
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]);
ggml_set_name(logits_rows, "logits_rows");
struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
ggml_set_name(max_logit, "max_logit");
// Calculate the threshold value.
struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p));
ggml_set_name(threshold, "min_p_threshold");
// Subtract the threshold from logits.
struct ggml_tensor * sub = ggml_sub(ctx, ggml_data->logits, threshold);
// Create a mask where logits below the threshold are 0 (discard),
// and others are 1 (keep).
struct ggml_tensor * mask = ggml_step(ctx, sub);
ggml_set_name(mask, "min_p_mask");
// Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
// min_p_bias = (mask * 1e9f) - 1e9f.
// So entries in the mask that we want to discard will become -1e9f, and
// others will be 0 (meaning they will not affect the logits).
const float large_val = 1e9f;
struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
ggml_set_name(min_p_bias, "min_p_bias");
// Add the min_p bias to the logits.
ggml_data->logits = ggml_add(ctx, ggml_data->logits, min_p_bias);
ggml_set_name(ggml_data->logits, "min_p_logits");
ggml_build_forward_expand(gf, ggml_data->logits);
}
static const char * llama_sampler_backend_min_p_name(const struct llama_sampler *) {
return "backend-min-p";
}
static void llama_sampler_backend_min_p_free(struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_backend_min_p_ctx *) smpl->ctx;
delete sctx;
}
static struct llama_sampler * llama_sampler_backend_min_p_clone(const struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_backend_min_p_ctx *) smpl->ctx;
return llama_sampler_backend_init_min_p(sctx->p);
}
struct llama_sampler * llama_sampler_backend_init_min_p(float p) {
static const llama_sampler_i iface = {
/*.name =*/ llama_sampler_backend_min_p_name,
/*.accept =*/ nullptr,
/*.apply =*/ nullptr,
/*.reset =*/ nullptr,
/*.clone =*/ llama_sampler_backend_min_p_clone,
/*.free =*/ llama_sampler_backend_min_p_free,
/*.apply_ggml =*/ llama_sampler_backend_min_p_apply_ggml,
/*.accept_ggml =*/ nullptr,
/*.set_input_ggml =*/ nullptr,
/*.init_ggml =*/ llama_sampler_backend_min_p_init_ggml,
};
auto * sctx = new llama_sampler_backend_min_p_ctx {
/*.p =*/ p,
/*.device =*/ nullptr,
};
auto * sampler = new llama_sampler {
/*.iface =*/ &iface,
/*.ctx =*/ sctx,
};
return sampler;
}
struct llama_sampler_backend_top_p_ctx {
float p;
// Only required for checking operation support and can be removed later.
ggml_backend_dev_t device;
};
static void llama_sampler_backend_top_p_init_ggml(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_backend_top_p_ctx *) smpl->ctx;
sctx->device = ggml_backend_buft_get_device(buft);
}
static void llama_sampler_backend_top_p_apply_ggml(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) {
auto * sctx = (llama_sampler_backend_top_p_ctx *) smpl->ctx;
struct ggml_tensor * softmax = ggml_soft_max(ctx, ggml_data->logits);
ggml_set_name(softmax, "top_p_softmax");
// Get the sorted indices of the softmax probabilities in descending order.
struct ggml_tensor * sorted_idx = ggml_argsort(ctx, softmax, GGML_SORT_ORDER_DESC);
ggml_set_name(sorted_idx, "top_p_sorted_idx");
// Do the sorting via reshape + get_rows
struct ggml_tensor * softmax_reshaped = ggml_reshape_2d(ctx, softmax, 1, softmax->ne[0]);
ggml_set_name(softmax_reshaped, "top_p_softmax_reshaped");
struct ggml_tensor * sorted_probs = ggml_get_rows(ctx, softmax_reshaped, sorted_idx);
ggml_set_name(sorted_probs, "top_p_sorted_probs");
struct ggml_tensor * sorted_probs_reshaped = ggml_reshape_2d(ctx, sorted_probs, softmax->ne[0], 1);
ggml_set_name(sorted_probs_reshaped, "top_p_sorted_probs_reshaped");
// Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM.
struct ggml_tensor * sorted_cdf = ggml_cumsum(ctx, sorted_probs_reshaped);
ggml_set_name(sorted_cdf, "top_p_sorted_cdf");
// Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep
struct ggml_tensor * sorted_cdf_scaled = ggml_scale_bias(ctx, sorted_cdf, -1.0f, sctx->p);
ggml_set_name(sorted_cdf_scaled, "top_p_sorted_cdf_scaled");
struct ggml_tensor * sorted_mask = ggml_step(ctx, sorted_cdf_scaled);
ggml_set_name(sorted_mask, "top_p_sorted_mask");
// reverse sorting by argsort(argsort)
// cast to F32 since CUDA only supports float inputs
struct ggml_tensor * reverse_argsort = ggml_argsort(ctx, ggml_cast(ctx, sorted_idx, GGML_TYPE_F32), GGML_SORT_ORDER_ASC);
ggml_set_name(reverse_argsort, "top_p_reverse_argsort");
// Do the sorting via reshape + get_rows
struct ggml_tensor * sorted_reshaped_mask = ggml_reshape_2d(ctx, sorted_mask, 1, sorted_mask->ne[0]);
ggml_set_name(sorted_reshaped_mask, "top_p_sorted_reshaped_mask");
struct ggml_tensor * reshaped_mask = ggml_get_rows(ctx, sorted_reshaped_mask, reverse_argsort);
ggml_set_name(reshaped_mask, "top_p_reshaped_mask");
struct ggml_tensor * mask = ggml_reshape_2d(ctx, reshaped_mask, sorted_mask->ne[0], 1);
ggml_set_name(mask, "top_p_mask");
// Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
// top_p_bias = (mask * 1e9f) - 1e9f.
// So entries in the mask that we want to discard will become -1e9f, and
// others will be 0 (meaning they will not affect the logits).
const float large_val = 1e9f;
struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
ggml_set_name(top_p_bias, "top_p_bias");
ggml_data->logits = ggml_add(ctx, ggml_data->logits, top_p_bias);
ggml_set_name(ggml_data->logits, "top_p_logits");
ggml_build_forward_expand(gf, ggml_data->logits);
}
static const char * llama_sampler_backend_top_p_name(const struct llama_sampler *) {
return "backend-top-p";
}
static void llama_sampler_backend_top_p_free(struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_backend_top_p_ctx *) smpl->ctx;
delete sctx;
}
static struct llama_sampler * llama_sampler_backend_top_p_clone(const struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_backend_top_p_ctx *) smpl->ctx;
return llama_sampler_backend_init_top_p(sctx->p);
}
struct llama_sampler * llama_sampler_backend_init_top_p(float p) {
static const llama_sampler_i iface = {
/*.name =*/ llama_sampler_backend_top_p_name,
/*.accept =*/ nullptr,
/*.apply =*/ nullptr,
/*.reset =*/ nullptr,
/*.clone =*/ llama_sampler_backend_top_p_clone,
/*.free =*/ llama_sampler_backend_top_p_free,
/*.apply_ggml =*/ llama_sampler_backend_top_p_apply_ggml,
/*.accept_ggml =*/ nullptr,
/*.set_input_ggml =*/ nullptr,
/*.init_ggml =*/ llama_sampler_backend_top_p_init_ggml,
};
auto * sctx = new llama_sampler_backend_top_p_ctx {
/*.p =*/ p,
/*.device =*/ nullptr,
};
auto * sampler = new llama_sampler {
/*.iface =*/ &iface,
/*.ctx =*/ sctx,
};
return sampler;
}


@@ -4,6 +4,8 @@
#include "llama-vocab.h"
#include "llama-grammar.h"
#include "ggml-cpp.h"
#include <array>
#include <algorithm>
#include <cassert>
@@ -610,16 +612,16 @@ static void llama_sampler_chain_set_backend_context(
}
static struct llama_sampler_i llama_sampler_chain_i = {
/* .name = */ llama_sampler_chain_name,
/* .accept = */ llama_sampler_chain_accept,
/* .apply = */ llama_sampler_chain_apply,
/* .reset = */ llama_sampler_chain_reset,
/* .clone = */ llama_sampler_chain_clone,
/* .free = */ llama_sampler_chain_free,
/* .apply_ggml = */ llama_sampler_chain_apply_ggml,
/* .accept_ggml = */ llama_sampler_chain_accept_ggml,
/* .set_input_ggml = */ llama_sampler_chain_set_input_ggml,
/* .set_backend_context = */ llama_sampler_chain_set_backend_context,
/* .name = */ llama_sampler_chain_name,
/* .accept = */ llama_sampler_chain_accept,
/* .apply = */ llama_sampler_chain_apply,
/* .reset = */ llama_sampler_chain_reset,
/* .clone = */ llama_sampler_chain_clone,
/* .free = */ llama_sampler_chain_free,
/* .apply_ggml = */ llama_sampler_chain_apply_ggml,
/* .accept_ggml = */ llama_sampler_chain_accept_ggml,
/* .set_input_ggml = */ llama_sampler_chain_set_input_ggml,
/* .init_ggml = */ llama_sampler_chain_set_backend_context,
};
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
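For orientation, the hunk above renames the last llama_sampler_i slot from set_backend_context to init_ggml. A sampler that participates in graph sampling only needs to fill the *_ggml hooks it uses; the following no-op is a minimal sketch against the field order shown in this diff (illustrative only, not part of the commit; a real sampler would also provide the CPU-side apply):

static const char * noop_sampler_name(const struct llama_sampler *) {
    return "noop";
}

// Leaves ggml_data->logits untouched; a stand-in showing where a custom
// sampler would add its ops to the graph.
static void noop_sampler_apply_ggml(
        struct llama_sampler * /*smpl*/,
        struct ggml_context  * /*ctx*/,
        struct ggml_cgraph   * /*gf*/,
        struct llama_sampler_ggml_data * /*ggml_data*/) {
}

static struct llama_sampler_i noop_sampler_i = {
    /* .name           = */ noop_sampler_name,
    /* .accept         = */ nullptr,
    /* .apply          = */ nullptr,
    /* .reset          = */ nullptr,
    /* .clone          = */ nullptr,
    /* .free           = */ nullptr,
    /* .apply_ggml     = */ noop_sampler_apply_ggml,
    /* .accept_ggml    = */ nullptr,
    /* .set_input_ggml = */ nullptr,
    /* .init_ggml      = */ nullptr,
};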
@@ -687,17 +689,29 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
}
}
static void llama_sampler_greedy_apply_ggml(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) {
GGML_UNUSED(gf);
GGML_UNUSED(smpl);
struct ggml_tensor * argmax_result = ggml_argmax(ctx, ggml_data->logits);
ggml_set_name(argmax_result, "argmax_result");
ggml_data->sampled = argmax_result;
}
static struct llama_sampler_i llama_sampler_greedy_i = {
/* .name = */ llama_sampler_greedy_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_greedy_apply,
/* .reset = */ nullptr,
/* .clone = */ nullptr,
/* .free = */ nullptr,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .set_backend_context = */ nullptr,
/* .name = */ llama_sampler_greedy_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_greedy_apply,
/* .reset = */ nullptr,
/* .clone = */ nullptr,
/* .free = */ nullptr,
/* .apply_ggml = */ llama_sampler_greedy_apply_ggml,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .init_ggml = */ nullptr,
};
struct llama_sampler * llama_sampler_init_greedy() {
@@ -714,6 +728,14 @@ struct llama_sampler_dist {
uint32_t seed_cur;
std::mt19937 rng;
// Only required for checking operation support and can be removed later.
ggml_backend_dev_t device;
struct ggml_tensor * inp_uniform;
ggml_context_ptr inp_ctx;
ggml_backend_buffer_ptr inp_buf;
};
static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
@@ -816,17 +838,109 @@ static void llama_sampler_dist_free(struct llama_sampler * smpl) {
delete (llama_sampler_dist *) smpl->ctx;
}
static void llama_sampler_dist_apply_ggml(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) {
GGML_UNUSED(gf);
auto * sctx = (llama_sampler_dist *) smpl->ctx;
struct ggml_tensor * probs = ggml_soft_max(ctx, ggml_data->logits);
ggml_set_name(probs, "dist_probs");
struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
if (sctx->device && !ggml_backend_dev_supports_op(sctx->device, cumsum)) {
fprintf(stderr, "Warning: backend does not support cumsum operation required for dist sampling\n");
fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
}
ggml_set_name(cumsum, "cumsum");
// The uniform tensor holds a random value which we subtract from the
// cumsum tensor (the uniform tensor is broadcast by ggml_sub).
// Recall that each entry in cumsum is the cumulative probability up to that
// index, so values stay negative while the cumulative total is below the
// random value, and become zero/positive once the threshold is crossed.
struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform);
ggml_set_name(diff, "dist_cumsum");
// The ggml_step function produces a tensor where entries are 1 if the
// corresponding entry in diff is > 0, and 0 otherwise. So all values before
// the index where the cumulative probability exceeds the random value are 0,
// and all entries from that index onward are 1.
struct ggml_tensor * mask = ggml_step(ctx, diff);
ggml_set_name(mask, "dist_mask");
// Taking the sum of the mask gives us the number of elements from the
// threshold index onward.
struct ggml_tensor * idxf = ggml_sum(ctx, mask);
ggml_set_name(idxf, "dist_index_f32");
// Use ggml_scale_bias to scale the index value by -1 and then add the size
// of the mask to that value so we get the correct index ((-1 * idxf) + n).
struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
ggml_set_name(idx, "dist_index_i32");
// Map back to original vocab ids if a candidates tensor is available.
struct ggml_tensor * sampled_token = idx;
if (ggml_data->candidates != nullptr) {
struct ggml_tensor * candidates = ggml_data->candidates;
struct ggml_tensor * candidates_reshaped = ggml_view_2d(ctx, candidates, 1, ggml_nelements(candidates),
ggml_type_size(candidates->type), 0);
sampled_token = ggml_get_rows(ctx, candidates_reshaped, idx);
ggml_set_name(sampled_token, "dist_sampled_token");
}
ggml_set_output(sampled_token);
ggml_data->sampled = sampled_token;
}
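The cumsum/step/sum construction above is an inverse-CDF draw expressed as graph ops. A standalone CPU sketch of the same selection math (illustrative only, not part of the commit):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Mirror of the graph: softmax the logits, then count the cumulative
// probabilities that exceed the uniform draw `rnd`; the sampled index
// is n minus that count ((-1 * idxf) + n in ggml_scale_bias terms).
static size_t dist_select_cpu(const std::vector<float> & logits, float rnd) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    std::vector<float> probs(logits.size());
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_logit);
        sum += probs[i];
    }
    float  cumsum  = 0.0f;
    size_t n_above = 0;
    for (size_t i = 0; i < probs.size(); ++i) {
        cumsum += probs[i] / sum;
        if (cumsum - rnd > 0.0f) {
            n_above++; // ggml_step(diff) contributes a 1 here
        }
    }
    const size_t idx = probs.size() - n_above;
    return idx < probs.size() ? idx : probs.size() - 1; // guard the cumsum ~ 1 rounding edge
}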
static void llama_sampler_dist_set_input_ggml(struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_dist *) smpl->ctx;
GGML_ASSERT(sctx->inp_uniform != nullptr);
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
const float rnd = dist(sctx->rng);
ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
}
static void llama_sampler_dist_init_ggml(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_dist *) smpl->ctx;
sctx->device = ggml_backend_buft_get_device(buft);
ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
sctx->inp_ctx.reset(ggml_init(params));
// Create the uniform random scalar input tensor. This will be set by
// llama_sampler_dist_set_input_ggml after this graph is built.
sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
ggml_set_name(sctx->inp_uniform, "uniform");
ggml_set_input(sctx->inp_uniform);
// Allocate all tensors from our context to the backend
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
}
static struct llama_sampler_i llama_sampler_dist_i = {
/* .name = */ llama_sampler_dist_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_dist_apply,
/* .reset = */ llama_sampler_dist_reset,
/* .clone = */ llama_sampler_dist_clone,
/* .free = */ llama_sampler_dist_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .set_backend_context = */ nullptr,
/* .name = */ llama_sampler_dist_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_dist_apply,
/* .reset = */ llama_sampler_dist_reset,
/* .clone = */ llama_sampler_dist_clone,
/* .free = */ llama_sampler_dist_free,
/* .apply_ggml = */ llama_sampler_dist_apply_ggml,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ llama_sampler_dist_set_input_ggml,
/* .init_ggml = */ llama_sampler_dist_init_ggml,
};
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
@@ -834,9 +948,13 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
return llama_sampler_init(
/* .iface = */ &llama_sampler_dist_i,
/* .ctx = */ new llama_sampler_dist {
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
/* .device = */ nullptr,
/* .inp_uniform = */ nullptr,
/* .inp_ctx = */ nullptr,
/* .inp_buf = */ nullptr,
}
);
}
@@ -845,6 +963,9 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
struct llama_sampler_top_k {
const int32_t k;
// Only required for checking operation support and can be removed later.
ggml_backend_dev_t device;
};
static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
@@ -865,24 +986,60 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
delete (llama_sampler_top_k *) smpl->ctx;
}
static void llama_sampler_top_k_apply_ggml(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) {
auto * ctx_data = (llama_sampler_top_k *) smpl->ctx;
struct ggml_tensor * top_k = ggml_top_k(ctx, ggml_data->logits, ctx_data->k);
ggml_set_name(top_k, "top_k");
// top_k is a view of argsort - check if backend supports the underlying argsort operation
// by checking the source tensor (which is the argsort result)
if (ctx_data->device && top_k->src[0] && !ggml_backend_dev_supports_op(ctx_data->device, top_k->src[0])) {
fprintf(stderr, "Warning: backend does not support argsort operation required for top-k sampling\n");
fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
}
ggml_data->candidates = top_k;
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]);
struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
ggml_set_name(top_k_rows, "top_k_rows");
ggml_data->logits = ggml_reshape_1d(ctx, top_k_rows, ctx_data->k);
ggml_build_forward_expand(gf, ggml_data->logits);
}
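The reshape + get_rows pair above is a gather: ggml_top_k yields the indices of the k largest logits, and get_rows pulls the matching logit values into a dense k-element tensor. A CPU sketch of the equivalent computation (illustrative only, not part of the commit):

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Equivalent of ggml_top_k + ggml_get_rows: the indices of the k
// largest logits (the candidates) and the gathered values, descending.
static void top_k_gather_cpu(const std::vector<float> & logits, int32_t k,
                             std::vector<int32_t> & candidates,
                             std::vector<float>   & top_logits) {
    const int32_t kk = std::min<int32_t>(k, (int32_t) logits.size());
    candidates.resize(logits.size());
    std::iota(candidates.begin(), candidates.end(), 0);
    std::partial_sort(candidates.begin(), candidates.begin() + kk, candidates.end(),
                      [&](int32_t a, int32_t b) { return logits[a] > logits[b]; });
    candidates.resize(kk);
    top_logits.resize(kk);
    for (int32_t i = 0; i < kk; ++i) {
        top_logits[i] = logits[candidates[i]]; // the get_rows step
    }
}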
static void llama_sampler_top_k_init_ggml(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * ctx_data = (llama_sampler_top_k *) smpl->ctx;
ctx_data->device = ggml_backend_buft_get_device(buft);
}
static struct llama_sampler_i llama_sampler_top_k_i = {
/* .name = */ llama_sampler_top_k_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_top_k_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_top_k_clone,
/* .free = */ llama_sampler_top_k_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .set_backend_context = */ nullptr,
/* .name = */ llama_sampler_top_k_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_top_k_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_top_k_clone,
/* .free = */ llama_sampler_top_k_free,
/* .apply_ggml = */ llama_sampler_top_k_apply_ggml,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .init_ggml = */ llama_sampler_top_k_init_ggml,
};
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
return llama_sampler_init(
/* .iface = */ &llama_sampler_top_k_i,
/* .ctx = */ new llama_sampler_top_k {
/* .k = */ k,
/* .k = */ k,
/* .device = */ nullptr,
}
);
}
@@ -894,6 +1051,9 @@ struct llama_sampler_top_p {
const size_t min_keep;
std::vector<llama_token_data> buf_sort;
// Only required for checking operation support and can be removed later.
ggml_backend_dev_t device;
};
static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
@@ -964,17 +1124,87 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
delete (llama_sampler_top_p *) smpl->ctx;
}
static void llama_sampler_top_p_apply_ggml(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) {
auto * sctx = (llama_sampler_top_p *) smpl->ctx;
struct ggml_tensor * softmax = ggml_soft_max(ctx, ggml_data->logits);
ggml_set_name(softmax, "top_p_softmax");
// Get the sorted indices of the softmax probabilities in descending order.
struct ggml_tensor * sorted_idx = ggml_argsort(ctx, softmax, GGML_SORT_ORDER_DESC);
ggml_set_name(sorted_idx, "top_p_sorted_idx");
// Do the sorting via reshape + get_rows
struct ggml_tensor * softmax_reshaped = ggml_reshape_2d(ctx, softmax, 1, softmax->ne[0]);
ggml_set_name(softmax_reshaped, "top_p_softmax_reshaped");
struct ggml_tensor * sorted_probs = ggml_get_rows(ctx, softmax_reshaped, sorted_idx);
ggml_set_name(sorted_probs, "top_p_sorted_probs");
struct ggml_tensor * sorted_probs_reshaped = ggml_reshape_2d(ctx, sorted_probs, softmax->ne[0], 1);
ggml_set_name(sorted_probs_reshaped, "top_p_sorted_probs_reshaped");
// Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM.
struct ggml_tensor * sorted_cdf = ggml_cumsum(ctx, sorted_probs_reshaped);
ggml_set_name(sorted_cdf, "top_p_sorted_cdf");
// Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep
struct ggml_tensor * sorted_cdf_scaled = ggml_scale_bias(ctx, sorted_cdf, -1.0f, sctx->p);
ggml_set_name(sorted_cdf_scaled, "top_p_sorted_cdf_scaled");
struct ggml_tensor * sorted_mask = ggml_step(ctx, sorted_cdf_scaled);
ggml_set_name(sorted_mask, "top_p_sorted_mask");
// reverse sorting by argsort(argsort)
// cast to F32 since CUDA only supports float inputs
struct ggml_tensor * reverse_argsort = ggml_argsort(ctx, ggml_cast(ctx, sorted_idx, GGML_TYPE_F32), GGML_SORT_ORDER_ASC);
ggml_set_name(reverse_argsort, "top_p_reverse_argsort");
// Do the sorting via reshape + get_rows
struct ggml_tensor * sorted_reshaped_mask = ggml_reshape_2d(ctx, sorted_mask, 1, sorted_mask->ne[0]);
ggml_set_name(sorted_reshaped_mask, "top_p_sorted_reshaped_mask");
struct ggml_tensor * reshaped_mask = ggml_get_rows(ctx, sorted_reshaped_mask, reverse_argsort);
ggml_set_name(reshaped_mask, "top_p_reshaped_mask");
struct ggml_tensor * mask = ggml_reshape_2d(ctx, reshaped_mask, sorted_mask->ne[0], 1);
ggml_set_name(mask, "top_p_mask");
// Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
// top_p_bias = (mask * 1e9f) - 1e9f.
// So entries in the mask that we want to discard will become -1e9f, and
// others will be 0 (meaning they will not affect the logits).
const float large_val = 1e9f;
struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
ggml_set_name(top_p_bias, "top_p_bias");
ggml_data->logits = ggml_add(ctx, ggml_data->logits, top_p_bias);
ggml_set_name(ggml_data->logits, "top_p_logits");
ggml_build_forward_expand(gf, ggml_data->logits);
}
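In scalar terms, the graph above builds an additive mask: sort the probabilities descending, walk the CDF until it reaches p, then scatter the keep/discard decision back to the original token order as 0 or -1e9 (the min_keep guarantee of the CPU path is not modeled by the graph). A CPU sketch of that mask (illustrative only, not part of the commit):

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Bias vector the top-p graph adds to the logits: 0 where the sorted
// CDF is still below p (keep), -1e9 otherwise (discard).
static std::vector<float> top_p_bias_cpu(const std::vector<float> & probs, float p) {
    const size_t n = probs.size();
    std::vector<int32_t> order(n);
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(),
              [&](int32_t a, int32_t b) { return probs[a] > probs[b]; });
    std::vector<float> bias(n, -1e9f);
    float cdf = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        cdf += probs[order[i]];
        if (p - cdf > 0.0f) { // ggml_step(-cdf + p) == 1 -> keep
            bias[order[i]] = 0.0f;
        }
    }
    return bias;
}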
static void llama_sampler_top_p_init_ggml(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_top_p *) smpl->ctx;
sctx->device = ggml_backend_buft_get_device(buft);
}
static struct llama_sampler_i llama_sampler_top_p_i = {
/* .name = */ llama_sampler_top_p_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_top_p_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_top_p_clone,
/* .free = */ llama_sampler_top_p_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .set_backend_context = */ nullptr,
/* .name = */ llama_sampler_top_p_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_top_p_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_top_p_clone,
/* .free = */ llama_sampler_top_p_free,
/* .apply_ggml = */ llama_sampler_top_p_apply_ggml,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .init_ggml = */ llama_sampler_top_p_init_ggml,
};
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
@@ -984,6 +1214,7 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
/* .p = */ p,
/* .min_keep = */ min_keep,
/* .buf_sort = */ {},
/* .device = */ nullptr,
}
);
}
@@ -993,6 +1224,9 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
struct llama_sampler_min_p {
const float p;
const size_t min_keep;
// Only required for checking operation support and can be removed later.
ggml_backend_dev_t device;
};
static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
@@ -1062,17 +1296,67 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
delete (llama_sampler_min_p *) smpl->ctx;
}
static void llama_sampler_min_p_apply_ggml(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) {
auto * sctx = (llama_sampler_min_p *) smpl->ctx;
struct ggml_tensor * max_idx = ggml_argmax(ctx, ggml_data->logits);
ggml_set_name(max_idx, "max_idx");
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]);
ggml_set_name(logits_rows, "logits_rows");
struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
ggml_set_name(max_logit, "max_logit");
// Calculate the threshold value.
struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p));
ggml_set_name(threshold, "min_p_threshold");
// Subtract the threshold from logits.
struct ggml_tensor * sub = ggml_sub(ctx, ggml_data->logits, threshold);
// Create a mask where logits below the threshold are 0 (discard),
// and others are 1 (keep).
struct ggml_tensor * mask = ggml_step(ctx, sub);
ggml_set_name(mask, "min_p_mask");
// Use ggml_scale_bias (output = (a * s) + b) which in this case becomes:
// min_p_bias = (mask * 1e9f) - 1e9f.
// So entries in the mask that we want to discard will become -1e9f, and
// others will be 0 (meaning they will not affect the logits).
const float large_val = 1e9f;
struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
ggml_set_name(min_p_bias, "min_p_bias");
// Add the min_p bias to the logits.
ggml_data->logits = ggml_add(ctx, ggml_data->logits, min_p_bias);
ggml_set_name(ggml_data->logits, "min_p_logits");
ggml_build_forward_expand(gf, ggml_data->logits);
}
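The threshold above works directly in logit space: keeping tokens with probability >= p * max_prob is the same as keeping logits >= max_logit + log(p), since softmax is monotone and both sides share the same normalizer. A scalar sketch of the filter (illustrative only, not part of the commit; min_keep is not modeled by the graph):

#include <algorithm>
#include <cmath>
#include <vector>

// Apply the min-p mask the way the graph does: token i survives iff
// logits[i] > max_logit + log(p); everything else is pushed far below
// any surviving logit via the mask * 1e9 - 1e9 trick.
static void min_p_mask_cpu(std::vector<float> & logits, float p) {
    const float max_logit = *std::max_element(logits.begin(), logits.end());
    const float threshold = max_logit + std::log(p);
    for (float & v : logits) {
        if (!(v - threshold > 0.0f)) { // ggml_step(sub) == 0 -> discard
            v += -1e9f;
        }
    }
}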
static void llama_sampler_min_p_init_ggml(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_min_p *) smpl->ctx;
sctx->device = ggml_backend_buft_get_device(buft);
}
static struct llama_sampler_i llama_sampler_min_p_i = {
/* .name = */ llama_sampler_min_p_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_min_p_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_min_p_clone,
/* .free = */ llama_sampler_min_p_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .set_backend_context = */ nullptr,
/* .name = */ llama_sampler_min_p_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_min_p_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_min_p_clone,
/* .free = */ llama_sampler_min_p_free,
/* .apply_ggml = */ llama_sampler_min_p_apply_ggml,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .init_ggml = */ llama_sampler_min_p_init_ggml,
};
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
@@ -1081,6 +1365,7 @@ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
/* .ctx = */ new llama_sampler_min_p {
/* .p = */ p,
/* .min_keep = */ min_keep,
/* .device = */ nullptr,
}
);
}
@@ -1166,16 +1451,16 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_typical_i = {
/* .name = */ llama_sampler_typical_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_typical_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_typical_clone,
/* .free = */ llama_sampler_typical_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .set_backend_context = */ nullptr,
/* .name = */ llama_sampler_typical_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_typical_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_typical_clone,
/* .free = */ llama_sampler_typical_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .init_ggml = */ nullptr,
};
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
@@ -1213,17 +1498,38 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
delete (llama_sampler_temp *) smpl->ctx;
}
static void llama_sampler_temp_apply_ggml(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) {
auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
if (ctx_data->temp <= 0.0f) {
return;
}
struct ggml_tensor * scaled = ggml_scale(ctx, ggml_data->logits, 1.0f / ctx_data->temp);
ggml_set_name(scaled, "temp_scaled");
// Make sure the scaled tensor is contiguous for subsequent operations
ggml_data->logits = ggml_cont(ctx, scaled);
ggml_set_name(ggml_data->logits, "temp_scaled_logits");
ggml_build_forward_expand(gf, ggml_data->logits);
}
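Scaling logits by 1/temperature before softmax sharpens the distribution for T < 1 and flattens it for T > 1; the graph path above simply leaves the logits untouched for non-positive temperatures. A scalar sketch (illustrative only, not part of the commit):

#include <vector>

// Same effect as the ggml_scale above: softmax(logits / T).
static void apply_temp_cpu(std::vector<float> & logits, float temp) {
    if (temp <= 0.0f) {
        return; // matches the early return in the graph sampler
    }
    for (float & v : logits) {
        v *= 1.0f / temp;
    }
}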
static struct llama_sampler_i llama_sampler_temp_i = {
/* .name = */ llama_sampler_temp_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_temp_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_temp_clone,
/* .free = */ llama_sampler_temp_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .set_backend_context = */ nullptr,
/* .name = */ llama_sampler_temp_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_temp_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_temp_clone,
/* .free = */ llama_sampler_temp_free,
/* .apply_ggml = */ llama_sampler_temp_apply_ggml,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .init_ggml = */ nullptr,
};
struct llama_sampler * llama_sampler_init_temp(float temp) {
@@ -1328,16 +1634,16 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_temp_ext_i = {
/* .name = */ llama_sampler_temp_ext_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_temp_ext_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_temp_ext_clone,
/* .free = */ llama_sampler_temp_ext_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .set_backend_context = */ nullptr,
/* .name = */ llama_sampler_temp_ext_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_temp_ext_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_temp_ext_clone,
/* .free = */ llama_sampler_temp_ext_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .init_ggml = */ nullptr,
};
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
@@ -1426,16 +1732,16 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_xtc_i = {
/* .name = */ llama_sampler_xtc_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sample_xtc_apply,
/* .reset = */ llama_sampler_xtc_reset,
/* .clone = */ llama_sampler_xtc_clone,
/* .free = */ llama_sampler_xtc_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .set_backend_context = */ nullptr,
/* .name = */ llama_sampler_xtc_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sample_xtc_apply,
/* .reset = */ llama_sampler_xtc_reset,
/* .clone = */ llama_sampler_xtc_clone,
/* .free = */ llama_sampler_xtc_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .init_ggml = */ nullptr,
};
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
@@ -1538,16 +1844,16 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_mirostat_i = {
/* .name = */ llama_sampler_mirostat_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_mirostat_apply,
/* .reset = */ llama_sampler_mirostat_reset,
/* .clone = */ llama_sampler_mirostat_clone,
/* .free = */ llama_sampler_mirostat_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .set_backend_context = */ nullptr,
/* .name = */ llama_sampler_mirostat_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_mirostat_apply,
/* .reset = */ llama_sampler_mirostat_reset,
/* .clone = */ llama_sampler_mirostat_clone,
/* .free = */ llama_sampler_mirostat_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .init_ggml = */ nullptr,
};
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
@@ -1641,16 +1947,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
/* .name = */ llama_sampler_mirostat_v2_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_mirostat_v2_apply,
/* .reset = */ llama_sampler_mirostat_v2_reset,
/* .clone = */ llama_sampler_mirostat_v2_clone,
/* .free = */ llama_sampler_mirostat_v2_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .set_backend_context = */ nullptr,
/* .name = */ llama_sampler_mirostat_v2_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_mirostat_v2_apply,
/* .reset = */ llama_sampler_mirostat_v2_reset,
/* .clone = */ llama_sampler_mirostat_v2_clone,
/* .free = */ llama_sampler_mirostat_v2_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .init_ggml = */ nullptr,
};
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
@@ -1762,16 +2068,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_grammar_i = {
/* .name = */ llama_sampler_grammar_name,
/* .accept = */ llama_sampler_grammar_accept_impl,
/* .apply = */ llama_sampler_grammar_apply,
/* .reset = */ llama_sampler_grammar_reset,
/* .clone = */ llama_sampler_grammar_clone,
/* .free = */ llama_sampler_grammar_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .set_backend_context = */ nullptr,
/* .name = */ llama_sampler_grammar_name,
/* .accept = */ llama_sampler_grammar_accept_impl,
/* .apply = */ llama_sampler_grammar_apply,
/* .reset = */ llama_sampler_grammar_reset,
/* .clone = */ llama_sampler_grammar_clone,
/* .free = */ llama_sampler_grammar_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .init_ggml = */ nullptr,
};
static struct llama_sampler * llama_sampler_init_grammar_impl(
@@ -1973,16 +2279,16 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_penalties_i = {
/* .name = */ llama_sampler_penalties_name,
/* .accept = */ llama_sampler_penalties_accept,
/* .apply = */ llama_sampler_penalties_apply,
/* .reset = */ llama_sampler_penalties_reset,
/* .clone = */ llama_sampler_penalties_clone,
/* .free = */ llama_sampler_penalties_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .set_backend_context = */ nullptr,
/* .name = */ llama_sampler_penalties_name,
/* .accept = */ llama_sampler_penalties_accept,
/* .apply = */ llama_sampler_penalties_apply,
/* .reset = */ llama_sampler_penalties_reset,
/* .clone = */ llama_sampler_penalties_clone,
/* .free = */ llama_sampler_penalties_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .init_ggml = */ nullptr,
};
struct llama_sampler * llama_sampler_init_penalties(
@@ -2068,16 +2374,16 @@ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
/* .name = */ llama_sampler_top_n_sigma_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_top_n_sigma_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_top_n_sigma_clone,
/* .free = */ llama_sampler_top_n_sigma_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .set_backend_context = */ nullptr,
/* .name = */ llama_sampler_top_n_sigma_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_top_n_sigma_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_top_n_sigma_clone,
/* .free = */ llama_sampler_top_n_sigma_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .init_ggml = */ nullptr,
};
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
@ -2402,16 +2708,16 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_dry_i = {
/* .name = */ llama_sampler_dry_name,
/* .accept = */ llama_sampler_dry_accept,
/* .apply = */ llama_sampler_dry_apply,
/* .reset = */ llama_sampler_dry_reset,
/* .clone = */ llama_sampler_dry_clone,
/* .free = */ llama_sampler_dry_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .set_backend_context = */ nullptr,
/* .name = */ llama_sampler_dry_name,
/* .accept = */ llama_sampler_dry_accept,
/* .apply = */ llama_sampler_dry_apply,
/* .reset = */ llama_sampler_dry_reset,
/* .clone = */ llama_sampler_dry_clone,
/* .free = */ llama_sampler_dry_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .init_ggml = */ nullptr,
};
struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
@ -2498,6 +2804,11 @@ struct llama_sampler_logit_bias {
const std::vector<llama_logit_bias> logit_bias;
std::vector<llama_logit_bias> to_search;
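// state for backend (ggml) sampling: the input tensor plus the context/buffer that own it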
struct ggml_tensor * inp_logit_bias;
ggml_context_ptr inp_ctx;
ggml_backend_buffer_ptr inp_buf;
};
static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
@ -2546,17 +2857,77 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
delete (llama_sampler_logit_bias *) smpl->ctx;
}
static void llama_sampler_logit_bias_apply_ggml(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) {
auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
if (sctx->logit_bias.empty()) {
return;
}
// Add the sparse logit bias tensor to the logits in-place (zero for unbiased tokens)
struct ggml_tensor * logit_biased = ggml_add_inplace(ctx, ggml_data->logits, sctx->inp_logit_bias);
ggml_build_forward_expand(gf, logit_biased);
}
static void llama_sampler_logit_bias_set_input_ggml(struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
if (sctx->logit_bias.empty()) {
return;
}
GGML_ASSERT(sctx->inp_logit_bias != nullptr);
// Create a sparse logit_bias vector from the logit_bias entries.
std::vector<float> logit_bias_sparse(sctx->n_vocab, 0.0f);
for (const auto & lb : sctx->logit_bias) {
GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
logit_bias_sparse[lb.token] = lb.bias;
}
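// upload the dense vector to the backend tensor (device memory for GPU backends)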
ggml_backend_tensor_set(sctx->inp_logit_bias, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
}
static void llama_sampler_logit_bias_init_ggml(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
if (sctx->logit_bias.empty()) {
return;
}
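// metadata-only context (no_alloc = true): the tensor data itself is allocated from buft below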
ggml_init_params params = {
/*.mem_size =*/ ggml_tensor_overhead(),
/*.mem_buffer =*/ nullptr,
/*.no_alloc =*/ true,
};
sctx->inp_ctx.reset(ggml_init(params));
sctx->inp_logit_bias = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, sctx->n_vocab);
ggml_set_name(sctx->inp_logit_bias, "logit_bias");
ggml_set_input(sctx->inp_logit_bias);
// Allocate all tensors from our context to the backend
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
}
static struct llama_sampler_i llama_sampler_logit_bias_i = {
/* .name = */ llama_sampler_logit_bias_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_logit_bias_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_logit_bias_clone,
/* .free = */ llama_sampler_logit_bias_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .set_backend_context = */ nullptr,
/* .name = */ llama_sampler_logit_bias_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_logit_bias_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_logit_bias_clone,
/* .free = */ llama_sampler_logit_bias_free,
/* .apply_ggml = */ llama_sampler_logit_bias_apply_ggml,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ llama_sampler_logit_bias_set_input_ggml,
/* .init_ggml = */ llama_sampler_logit_bias_init_ggml,
};
struct llama_sampler * llama_sampler_init_logit_bias(
@ -2566,9 +2937,12 @@ struct llama_sampler * llama_sampler_init_logit_bias(
return llama_sampler_init(
/* .iface = */ &llama_sampler_logit_bias_i,
/* .ctx = */ new llama_sampler_logit_bias {
/* .n_vocab = */ n_vocab,
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
/* .to_search = */ {},
/* .n_vocab = */ n_vocab,
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
/* .to_search = */ {},
/* .inp_logit_bias = */ nullptr,
/* .inp_ctx = */ nullptr,
/* .inp_buf = */ nullptr,
}
);
}
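// A minimal usage sketch (not part of this diff) of how the three ggml hooks
// cooperate once a logit-bias sampler is attached to a sequence, following the
// same calls exercised by the tests below; `ctx` and `vocab` are assumed to
// come from the usual llama_init_from_model() / llama_model_get_vocab() setup:
//
//     std::vector<llama_logit_bias> bias = { { /*token =*/ 42, /*bias =*/ 5.0f } };
//
//     llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
//     llama_sampler_chain_add(chain, llama_sampler_init_logit_bias(
//             llama_vocab_n_tokens(vocab), bias.size(), bias.data()));
//     llama_sampler_chain_add(chain, llama_sampler_init_dist(88));
//
//     // init_ggml allocates inp_logit_bias on the backend buffer type,
//     // apply_ggml adds it into the logits while the sampling graph is built,
//     // and set_input_ggml uploads the sparse bias values before each decode
//     llama_set_backend_sampler(ctx, /*seq_id =*/ 0, chain);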
@ -2781,16 +3155,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) {
}
static struct llama_sampler_i llama_sampler_infill_i = {
/* .name = */ llama_sampler_infill_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_infill_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_infill_clone,
/* .free = */ llama_sampler_infill_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .set_backend_context = */ nullptr,
/* .name = */ llama_sampler_infill_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_infill_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_infill_clone,
/* .free = */ llama_sampler_infill_free,
/* .apply_ggml = */ nullptr,
/* .accept_ggml = */ nullptr,
/* .set_input_ggml = */ nullptr,
/* .init_ggml = */ nullptr,
};
struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {

View File

@ -277,7 +277,7 @@ static void test_backend_greedy_sampling(const char * model_path) {
struct llama_sampler_chain_params backend_sampler_params = llama_sampler_chain_default_params();
struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_sampler_params);
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_greedy());
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_greedy());
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@ -315,7 +315,7 @@ static void test_backend_top_k_sampling(const char * model_path) {
const int32_t k = 8;
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_top_k(k));
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_top_k(k));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@ -363,12 +363,12 @@ static void test_backend_temp_sampling(const char * model_path) {
const float temp_0 = 0.8f;
struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
struct llama_sampler * backend_sampler_chain_0 = llama_sampler_chain_init(backend_chain_params_0);
llama_sampler_chain_add(backend_sampler_chain_0, llama_sampler_backend_init_temp(temp_0));
llama_sampler_chain_add(backend_sampler_chain_0, llama_sampler_init_temp(temp_0));
const float temp_1 = 0.1f;
struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params();
struct llama_sampler * backend_sampler_chain_1 = llama_sampler_chain_init(backend_chain_params_1);
llama_sampler_chain_add(backend_sampler_chain_1, llama_sampler_backend_init_temp(temp_1));
llama_sampler_chain_add(backend_sampler_chain_1, llama_sampler_init_temp(temp_1));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {
{ 0, backend_sampler_chain_0 },
@ -430,7 +430,7 @@ static void test_backend_min_p_sampling(const char * model_path) {
const float p = 0.1;
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_min_p(p));
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_min_p(p, 0));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@ -488,7 +488,7 @@ static void test_backend_top_p_sampling(const char * model_path) {
const float p = 0.9;
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_top_p(p));
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_top_p(p, 0));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@ -541,12 +541,12 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
llama_sampler_chain_add(sampler_chain_0, llama_sampler_backend_init_greedy());
llama_sampler_chain_add(sampler_chain_0, llama_sampler_init_greedy());
struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
struct llama_sampler * sampler_chain_1 = llama_sampler_chain_init(chain_params_1);
llama_sampler_chain_add(sampler_chain_1, llama_sampler_backend_init_temp(0.8f));
llama_sampler_chain_add(sampler_chain_1, llama_sampler_backend_init_greedy());
llama_sampler_chain_add(sampler_chain_1, llama_sampler_init_temp(0.8f));
llama_sampler_chain_add(sampler_chain_1, llama_sampler_init_greedy());
std::vector<llama_sampler_seq_config> backend_sampler_configs = {
{ 0, sampler_chain_0 },
@ -613,7 +613,7 @@ static void test_backend_dist_sampling(const char * model_path) {
const int32_t seed = 88;
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed));
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@ -642,7 +642,7 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
const int32_t seed = 88;
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed));
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@ -689,11 +689,11 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_logit_bias(
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_logit_bias(
llama_vocab_n_tokens(test_ctx.vocab),
logit_bias.size(),
logit_bias.data()));
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(88));
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(88));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {
{ seq_id, backend_sampler_chain },
@ -720,12 +720,12 @@ static void test_backend_mixed_sampling(const char * model_path) {
struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
llama_sampler_chain_add(sampler_chain_0, llama_sampler_backend_init_dist(88));
llama_sampler_chain_add(sampler_chain_0, llama_sampler_init_dist(88));
int k = 40;
struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
struct llama_sampler * sampler_chain_1 = llama_sampler_chain_init(chain_params_1);
llama_sampler_chain_add(sampler_chain_1, llama_sampler_backend_init_top_k(k));
llama_sampler_chain_add(sampler_chain_1, llama_sampler_init_top_k(k));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {
{ 0, sampler_chain_0 },
@ -776,7 +776,7 @@ static void test_backend_set_sampler(const char * model_path) {
const int seq_id = 0;
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed));
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
@ -822,8 +822,8 @@ static void test_backend_set_sampler(const char * model_path) {
// Set a new backend sampler for the sequence.
struct llama_sampler_chain_params new_backend_chain_params = llama_sampler_chain_default_params();
struct llama_sampler * new_backend_sampler_chain = llama_sampler_chain_init(new_backend_chain_params);
llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_backend_init_top_k(20));
llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_backend_init_dist(seed));
llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_init_top_k(20));
llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_init_dist(seed));
llama_set_backend_sampler(test_ctx.ctx, seq_id, new_backend_sampler_chain);
if (!test_ctx.decode_tokens(tokens2)) {
@ -841,7 +841,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
// Sequence 0 uses backend sampling
struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
llama_sampler_chain_add(sampler_chain_0, llama_sampler_backend_init_dist(88));
llama_sampler_chain_add(sampler_chain_0, llama_sampler_init_dist(88));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {
{ 0, sampler_chain_0 },
@ -912,7 +912,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
{
struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
struct llama_sampler * sampler_chain = llama_sampler_chain_init(chain_params);
llama_sampler_chain_add(sampler_chain, llama_sampler_backend_init_dist(88));
llama_sampler_chain_add(sampler_chain, llama_sampler_init_dist(88));
llama_set_backend_sampler(test_ctx.ctx, 0, sampler_chain);
@ -937,7 +937,7 @@ static void test_backend_max_outputs(const char * model_path) {
const int32_t seed = 88;
llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed));
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_init_dist(seed));
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
if (!test_ctx.setup(model_path, backend_sampler_configs)) {