llama : cleanup + naming

This commit is contained in:
Georgi Gerganov 2025-11-29 22:37:07 +02:00
parent fbc8f49f3c
commit 9028ebfea8
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
5 changed files with 335 additions and 374 deletions

View File

@ -210,13 +210,6 @@ extern "C" {
bool sorted; // note: do not assume the data is sorted - always check this flag bool sorted; // note: do not assume the data is sorted - always check this flag
} llama_token_data_array; } llama_token_data_array;
struct llama_sampler_ggml_data {
struct ggml_tensor * logits;
struct ggml_tensor * probs;
struct ggml_tensor * sampled;
struct ggml_tensor * candidates;
};
typedef bool (*llama_progress_callback)(float progress, void * user_data); typedef bool (*llama_progress_callback)(float progress, void * user_data);
// Input data for llama_encode/llama_decode // Input data for llama_encode/llama_decode
@ -1181,11 +1174,16 @@ extern "C" {
// //
// llama_sampler_free(smpl); // llama_sampler_free(smpl);
// //
// TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
//
typedef void * llama_sampler_context_t; typedef void * llama_sampler_context_t;
struct llama_sampler_backend_data {
struct ggml_tensor * logits;
struct ggml_tensor * probs;
struct ggml_tensor * sampled;
struct ggml_tensor * candidates;
};
// user code can implement the interface below in order to create custom llama_sampler // user code can implement the interface below in order to create custom llama_sampler
struct llama_sampler_i { struct llama_sampler_i {
const char * (*name) (const struct llama_sampler * smpl); // can be NULL const char * (*name) (const struct llama_sampler * smpl); // can be NULL
@ -1195,25 +1193,28 @@ extern "C" {
struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL
void (*apply_ggml)( struct llama_sampler * smpl, // backend sampling interface
struct ggml_context * ctx, void (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft);
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data);
void (*accept_ggml)( struct llama_sampler * smpl, void (*backend_accept)(
struct ggml_context * ctx, struct llama_sampler * smpl,
struct ggml_cgraph * gf, struct ggml_context * ctx,
struct ggml_tensor * selected_token); struct ggml_cgraph * gf,
struct ggml_tensor * selected_token);
void (*set_input_ggml)(struct llama_sampler * smpl); void (*backend_apply)(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_backend_data * ggml_data);
void (*init_ggml)(struct llama_sampler * smpl, void (*backend_set_input)(struct llama_sampler * smpl);
ggml_backend_buffer_type_t buft);
}; };
struct llama_sampler { struct llama_sampler {
const struct llama_sampler_i * iface; const struct llama_sampler_i * iface;
llama_sampler_context_t ctx;
llama_sampler_context_t ctx;
}; };
LLAMA_API void llama_set_backend_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl); LLAMA_API void llama_set_backend_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
@ -1228,17 +1229,6 @@ extern "C" {
// important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add) // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl);
LLAMA_API void llama_sampler_init_ggml (struct llama_sampler * smpl, ggml_backend_buffer_type_t buft);
LLAMA_API void llama_sampler_set_input_ggml(struct llama_sampler * smpl);
LLAMA_API void llama_sampler_apply_ggml (struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data);
LLAMA_API void llama_sampler_accept_ggml (struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct ggml_tensor * selected_token);
// llama_sampler_chain // llama_sampler_chain
// a type of llama_sampler that can chain multiple samplers one after another // a type of llama_sampler that can chain multiple samplers one after another

View File

@ -465,8 +465,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) { void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
GGML_UNUSED(ubatch); GGML_UNUSED(ubatch);
for (const auto & [seq_id, sampler] : samplers) { for (const auto & [seq_id, sampler] : samplers) {
if (sampler->iface->set_input_ggml) { if (sampler->iface->backend_set_input) {
sampler->iface->set_input_ggml(sampler); sampler->iface->backend_set_input(sampler);
} }
} }
} }
@ -2088,8 +2088,9 @@ void llm_graph_context::build_sampling() const {
const int32_t row_idx = it->second; const int32_t row_idx = it->second;
// Allow GPU sampler to create input tensors by implementing init_ggml. // Allow GPU sampler to create input tensors by implementing init_ggml.
if (sampler->iface->init_ggml != nullptr) { // TODO: this should not be done here
sampler->iface->init_ggml(sampler, buft); if (sampler->iface->backend_init != nullptr) {
sampler->iface->backend_init(sampler, buft);
} }
active_samplers[seq_id] = sampler; active_samplers[seq_id] = sampler;
@ -2097,33 +2098,34 @@ void llm_graph_context::build_sampling() const {
ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, n_vocab, row_idx * logits_t->nb[1]); ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, n_vocab, row_idx * logits_t->nb[1]);
ggml_format_name(logits_seq, "logits_seq_%d", seq_id); ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
struct llama_sampler_ggml_data ggml_data = { struct llama_sampler_backend_data data = {
/*.logits =*/ logits_seq, /*.logits =*/ logits_seq,
/*.probs =*/ nullptr, /*.probs =*/ nullptr,
/*.sampled =*/ nullptr, /*.sampled =*/ nullptr,
/*.candidates =*/ nullptr, /*.candidates =*/ nullptr,
}; };
llama_sampler_apply_ggml(sampler, ctx0, gf, &ggml_data); assert(sampler->iface->backend_apply);
sampler->iface->backend_apply(sampler, ctx0, gf, &data);
if (ggml_data.sampled != nullptr) { if (data.sampled != nullptr) {
res->t_sampled[seq_id] = ggml_data.sampled; res->t_sampled[seq_id] = data.sampled;
ggml_build_forward_expand(gf, ggml_data.sampled); ggml_build_forward_expand(gf, data.sampled);
} }
if (ggml_data.probs != nullptr) { if (data.probs != nullptr) {
res->t_sampled_probs[seq_id] = ggml_data.probs; res->t_sampled_probs[seq_id] = data.probs;
ggml_build_forward_expand(gf, ggml_data.probs); ggml_build_forward_expand(gf, data.probs);
} }
if (ggml_data.logits != logits_seq) { if (data.logits != logits_seq) {
res->t_sampled_logits[seq_id] = ggml_data.logits; res->t_sampled_logits[seq_id] = data.logits;
ggml_build_forward_expand(gf, res->t_sampled_logits[seq_id]); ggml_build_forward_expand(gf, res->t_sampled_logits[seq_id]);
} }
if (ggml_data.candidates != nullptr) { if (data.candidates != nullptr) {
res->t_candidates[seq_id] = ggml_data.candidates; res->t_candidates[seq_id] = data.candidates;
ggml_build_forward_expand(gf, ggml_data.candidates); ggml_build_forward_expand(gf, data.candidates);
} }
} }

View File

@ -348,7 +348,9 @@ static uint32_t get_rng_seed(uint32_t seed) {
// llama_sampler API // llama_sampler API
struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) { struct llama_sampler * llama_sampler_init(
const struct llama_sampler_i * iface,
llama_sampler_context_t ctx) {
return new llama_sampler { return new llama_sampler {
/* .iface = */ iface, /* .iface = */ iface,
/* .ctx = */ ctx, /* .ctx = */ ctx,
@ -374,39 +376,6 @@ void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_ar
smpl->iface->apply(smpl, cur_p); smpl->iface->apply(smpl, cur_p);
} }
void llama_sampler_apply_ggml(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) {
GGML_ASSERT(smpl->iface->apply_ggml);
smpl->iface->apply_ggml(smpl, ctx, gf, ggml_data);
}
void llama_sampler_accept_ggml(
struct llama_sampler * smpl,
ggml_context * ctx,
ggml_cgraph * gf,
struct ggml_tensor * selected_token) {
if (smpl->iface->accept_ggml) {
smpl->iface->accept_ggml(smpl, ctx, gf, selected_token);
}
}
void llama_sampler_set_input_ggml(struct llama_sampler * smpl) {
if (smpl->iface->set_input_ggml) {
smpl->iface->set_input_ggml(smpl);
}
}
void llama_sampler_init_ggml(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
if (smpl->iface->init_ggml) {
smpl->iface->init_ggml(smpl, buft);
}
}
void llama_sampler_reset(struct llama_sampler * smpl) { void llama_sampler_reset(struct llama_sampler * smpl) {
if (smpl->iface->reset) { if (smpl->iface->reset) {
smpl->iface->reset(smpl); smpl->iface->reset(smpl);
@ -523,10 +492,10 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d
time_meas tm(chain->t_sample_us, chain->params.no_perf); time_meas tm(chain->t_sample_us, chain->params.no_perf);
for (auto * smpl : chain->samplers) { for (auto * smpl : chain->samplers) {
// Skip GPU samplers - they have apply_ggml but no apply
if (smpl->iface->apply == nullptr) { if (smpl->iface->apply == nullptr) {
continue; continue;
} }
llama_sampler_apply(smpl, cur_p); llama_sampler_apply(smpl, cur_p);
} }
} }
@ -561,21 +530,19 @@ static void llama_sampler_chain_free(struct llama_sampler * smpl) {
delete chain; delete chain;
} }
static void llama_sampler_chain_apply_ggml( static void llama_sampler_chain_backend_init(
struct llama_sampler * smpl, struct llama_sampler * smpl,
struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) {
auto * chain = (llama_sampler_chain *) smpl->ctx; auto * chain = (llama_sampler_chain *) smpl->ctx;
for (auto * smpl : chain->samplers) { for (auto * smpl : chain->samplers) {
if (smpl->iface->apply_ggml) { if (smpl->iface->backend_init) {
smpl->iface->apply_ggml(smpl, ctx, gf, ggml_data); smpl->iface->backend_init(smpl,buft);
} }
} }
} }
static void llama_sampler_chain_accept_ggml( static void llama_sampler_chain_backend_accept(
struct llama_sampler * smpl, struct llama_sampler * smpl,
ggml_context * ctx, ggml_context * ctx,
ggml_cgraph * gf, ggml_cgraph * gf,
@ -583,45 +550,47 @@ static void llama_sampler_chain_accept_ggml(
auto * chain = (llama_sampler_chain *) smpl->ctx; auto * chain = (llama_sampler_chain *) smpl->ctx;
for (auto * smpl : chain->samplers) { for (auto * smpl : chain->samplers) {
if (smpl->iface->accept_ggml) { if (smpl->iface->backend_accept) {
smpl->iface->accept_ggml(smpl, ctx, gf, selected_token); smpl->iface->backend_accept(smpl, ctx, gf, selected_token);
} }
} }
} }
static void llama_sampler_chain_set_input_ggml(struct llama_sampler * smpl) { static void llama_sampler_chain_backend_apply(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_backend_data * data) {
auto * chain = (llama_sampler_chain *) smpl->ctx; auto * chain = (llama_sampler_chain *) smpl->ctx;
for (auto * smpl : chain->samplers) { for (auto * smpl : chain->samplers) {
if (smpl->iface->set_input_ggml) { if (smpl->iface->backend_apply) {
smpl->iface->set_input_ggml(smpl); smpl->iface->backend_apply(smpl, ctx, gf, data);
} }
} }
} }
static void llama_sampler_chain_set_backend_context( static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) {
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * chain = (llama_sampler_chain *) smpl->ctx; auto * chain = (llama_sampler_chain *) smpl->ctx;
for (auto * smpl : chain->samplers) { for (auto * smpl : chain->samplers) {
if (smpl->iface->init_ggml) { if (smpl->iface->backend_set_input) {
smpl->iface->init_ggml(smpl,buft); smpl->iface->backend_set_input(smpl);
} }
} }
} }
static struct llama_sampler_i llama_sampler_chain_i = { static struct llama_sampler_i llama_sampler_chain_i = {
/* .name = */ llama_sampler_chain_name, /* .name = */ llama_sampler_chain_name,
/* .accept = */ llama_sampler_chain_accept, /* .accept = */ llama_sampler_chain_accept,
/* .apply = */ llama_sampler_chain_apply, /* .apply = */ llama_sampler_chain_apply,
/* .reset = */ llama_sampler_chain_reset, /* .reset = */ llama_sampler_chain_reset,
/* .clone = */ llama_sampler_chain_clone, /* .clone = */ llama_sampler_chain_clone,
/* .free = */ llama_sampler_chain_free, /* .free = */ llama_sampler_chain_free,
/* .apply_ggml = */ llama_sampler_chain_apply_ggml, /* .backend_init = */ llama_sampler_chain_backend_init,
/* .accept_ggml = */ llama_sampler_chain_accept_ggml, /* .backend_accept = */ llama_sampler_chain_backend_accept,
/* .set_input_ggml = */ llama_sampler_chain_set_input_ggml, /* .backend_apply = */ llama_sampler_chain_backend_apply,
/* .init_ggml = */ llama_sampler_chain_set_backend_context, /* .backend_set_input = */ llama_sampler_chain_backend_set_input,
}; };
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
@ -689,29 +658,29 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
} }
} }
static void llama_sampler_greedy_apply_ggml( static void llama_sampler_greedy_backend_apply(
struct llama_sampler * smpl, struct llama_sampler * smpl,
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_cgraph * gf, struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) { struct llama_sampler_backend_data * data) {
GGML_UNUSED(gf); GGML_UNUSED(gf);
GGML_UNUSED(smpl); GGML_UNUSED(smpl);
struct ggml_tensor * argmax_result = ggml_argmax(ctx, ggml_data->logits); struct ggml_tensor * argmax_result = ggml_argmax(ctx, data->logits);
ggml_set_name(argmax_result, "argmax_result"); ggml_set_name(argmax_result, "argmax_result");
ggml_data->sampled = argmax_result; data->sampled = argmax_result;
} }
static struct llama_sampler_i llama_sampler_greedy_i = { static struct llama_sampler_i llama_sampler_greedy_i = {
/* .name = */ llama_sampler_greedy_name, /* .name = */ llama_sampler_greedy_name,
/* .accept = */ nullptr, /* .accept = */ nullptr,
/* .apply = */ llama_sampler_greedy_apply, /* .apply = */ llama_sampler_greedy_apply,
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ nullptr, /* .clone = */ nullptr,
/* .free = */ nullptr, /* .free = */ nullptr,
/* .apply_ggml = */ llama_sampler_greedy_apply_ggml, /* .backend_init = */ nullptr,
/* .accept_ggml = */ nullptr, /* .backend_accept = */ nullptr,
/* .set_input_ggml = */ nullptr, /* .backend_apply = */ llama_sampler_greedy_backend_apply,
/* .init_ggml = */ nullptr, /* .backend_set_input = */ nullptr,
}; };
struct llama_sampler * llama_sampler_init_greedy() { struct llama_sampler * llama_sampler_init_greedy() {
@ -838,15 +807,24 @@ static void llama_sampler_dist_free(struct llama_sampler * smpl) {
delete (llama_sampler_dist *) smpl->ctx; delete (llama_sampler_dist *) smpl->ctx;
} }
static void llama_sampler_dist_apply_ggml( static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
struct llama_sampler * smpl, auto * sctx = (llama_sampler_dist *) smpl->ctx;
struct ggml_context * ctx, GGML_ASSERT(sctx->inp_uniform != nullptr);
struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) { std::uniform_real_distribution<float> dist(0.0f, 1.0f);
const float rnd = dist(sctx->rng);
ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
}
static void llama_sampler_dist_backend_apply(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_backend_data * data) {
GGML_UNUSED(gf); GGML_UNUSED(gf);
auto * sctx = (llama_sampler_dist *) smpl->ctx; auto * sctx = (llama_sampler_dist *) smpl->ctx;
struct ggml_tensor * probs = ggml_soft_max(ctx, ggml_data->logits); struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits);
ggml_set_name(probs, "dist_probs"); ggml_set_name(probs, "dist_probs");
struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs); struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
@ -883,8 +861,8 @@ static void llama_sampler_dist_apply_ggml(
// Map back to original vocab ids if a candidates tensor is available. // Map back to original vocab ids if a candidates tensor is available.
struct ggml_tensor * sampled_token = idx; struct ggml_tensor * sampled_token = idx;
if (ggml_data->candidates != nullptr) { if (data->candidates != nullptr) {
struct ggml_tensor * candidates = ggml_data->candidates; struct ggml_tensor * candidates = data->candidates;
struct ggml_tensor * candidates_reshaped = ggml_view_2d(ctx, candidates, 1, ggml_nelements(candidates), struct ggml_tensor * candidates_reshaped = ggml_view_2d(ctx, candidates, 1, ggml_nelements(candidates),
ggml_type_size(candidates->type), 0); ggml_type_size(candidates->type), 0);
@ -893,19 +871,10 @@ static void llama_sampler_dist_apply_ggml(
} }
ggml_set_output(sampled_token); ggml_set_output(sampled_token);
ggml_data->sampled = sampled_token; data->sampled = sampled_token;
} }
static void llama_sampler_dist_set_input_ggml(struct llama_sampler * smpl) { static void llama_sampler_dist_backend_init(
auto * sctx = (llama_sampler_dist *) smpl->ctx;
GGML_ASSERT(sctx->inp_uniform != nullptr);
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
const float rnd = dist(sctx->rng);
ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
}
static void llama_sampler_dist_init_ggml(
struct llama_sampler * smpl, struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) { ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_dist *) smpl->ctx; auto * sctx = (llama_sampler_dist *) smpl->ctx;
@ -921,7 +890,7 @@ static void llama_sampler_dist_init_ggml(
sctx->inp_ctx.reset(ggml_init(params)); sctx->inp_ctx.reset(ggml_init(params));
// Create the uniform random scalar input tensor. This will be set by // Create the uniform random scalar input tensor. This will be set by
// llama_sampler_dist_set_input_ggml after this graph is built. // llama_sampler_dist_backend_set_input after this graph is built.
sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1); sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
ggml_set_name(sctx->inp_uniform, "uniform"); ggml_set_name(sctx->inp_uniform, "uniform");
ggml_set_input(sctx->inp_uniform); ggml_set_input(sctx->inp_uniform);
@ -931,16 +900,16 @@ static void llama_sampler_dist_init_ggml(
} }
static struct llama_sampler_i llama_sampler_dist_i = { static struct llama_sampler_i llama_sampler_dist_i = {
/* .name = */ llama_sampler_dist_name, /* .name = */ llama_sampler_dist_name,
/* .accept = */ nullptr, /* .accept = */ nullptr,
/* .apply = */ llama_sampler_dist_apply, /* .apply = */ llama_sampler_dist_apply,
/* .reset = */ llama_sampler_dist_reset, /* .reset = */ llama_sampler_dist_reset,
/* .clone = */ llama_sampler_dist_clone, /* .clone = */ llama_sampler_dist_clone,
/* .free = */ llama_sampler_dist_free, /* .free = */ llama_sampler_dist_free,
/* .apply_ggml = */ llama_sampler_dist_apply_ggml, /* .backend_init = */ llama_sampler_dist_backend_init,
/* .accept_ggml = */ nullptr, /* .backend_accept = */ nullptr,
/* .set_input_ggml = */ llama_sampler_dist_set_input_ggml, /* .backend_apply = */ llama_sampler_dist_backend_apply,
/* .init_ggml = */ llama_sampler_dist_init_ggml, /* .backend_set_input = */ llama_sampler_dist_backend_set_input,
}; };
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
@ -986,15 +955,22 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
delete (llama_sampler_top_k *) smpl->ctx; delete (llama_sampler_top_k *) smpl->ctx;
} }
static void llama_sampler_top_k_apply_ggml( static void llama_sampler_top_k_backend_init(
struct llama_sampler * smpl, struct llama_sampler * smpl,
struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
struct ggml_cgraph * gf, auto * ctx_data = (llama_sampler_top_k *) smpl->ctx;
struct llama_sampler_ggml_data * ggml_data) { ctx_data->device = ggml_backend_buft_get_device(buft);
}
static void llama_sampler_top_k_backend_apply(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_backend_data * data) {
auto * ctx_data = (llama_sampler_top_k *) smpl->ctx; auto * ctx_data = (llama_sampler_top_k *) smpl->ctx;
struct ggml_tensor * top_k = ggml_top_k(ctx, ggml_data->logits, ctx_data->k); struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, ctx_data->k);
ggml_set_name(top_k, "top_k"); ggml_set_name(top_k, "top_k");
// top_k is a view of argsort - check if backend supports the underlying argsort operation // top_k is a view of argsort - check if backend supports the underlying argsort operation
@ -1004,34 +980,27 @@ static void llama_sampler_top_k_apply_ggml(
fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n"); fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
} }
ggml_data->candidates = top_k; data->candidates = top_k;
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]); struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k); struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
ggml_set_name(top_k_rows, "top_k_rows"); ggml_set_name(top_k_rows, "top_k_rows");
ggml_data->logits = ggml_reshape_1d(ctx, top_k_rows, ctx_data->k); data->logits = ggml_reshape_1d(ctx, top_k_rows, ctx_data->k);
ggml_build_forward_expand(gf, ggml_data->logits); ggml_build_forward_expand(gf, data->logits);
}
static void llama_sampler_top_k_init_ggml(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * ctx_data = (llama_sampler_top_k *) smpl->ctx;
ctx_data->device = ggml_backend_buft_get_device(buft);
} }
static struct llama_sampler_i llama_sampler_top_k_i = { static struct llama_sampler_i llama_sampler_top_k_i = {
/* .name = */ llama_sampler_top_k_name, /* .name = */ llama_sampler_top_k_name,
/* .accept = */ nullptr, /* .accept = */ nullptr,
/* .apply = */ llama_sampler_top_k_apply, /* .apply = */ llama_sampler_top_k_apply,
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ llama_sampler_top_k_clone, /* .clone = */ llama_sampler_top_k_clone,
/* .free = */ llama_sampler_top_k_free, /* .free = */ llama_sampler_top_k_free,
/* .apply_ggml = */ llama_sampler_top_k_apply_ggml, /* .backend_init = */ llama_sampler_top_k_backend_init,
/* .accept_ggml = */ nullptr, /* .backend_accept = */ nullptr,
/* .set_input_ggml = */ nullptr, /* .backend_apply = */ llama_sampler_top_k_backend_apply,
/* .init_ggml = */ llama_sampler_top_k_init_ggml, /* .backend_set_input = */ nullptr,
}; };
struct llama_sampler * llama_sampler_init_top_k(int32_t k) { struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
@ -1124,14 +1093,21 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
delete (llama_sampler_top_p *) smpl->ctx; delete (llama_sampler_top_p *) smpl->ctx;
} }
static void llama_sampler_top_p_apply_ggml( static void llama_sampler_top_p_backend_init(
struct llama_sampler * smpl, struct llama_sampler * smpl,
struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
struct ggml_cgraph * gf, auto * sctx = (llama_sampler_top_p *) smpl->ctx;
struct llama_sampler_ggml_data * ggml_data) { sctx->device = ggml_backend_buft_get_device(buft);
}
static void llama_sampler_top_p_backend_apply(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_backend_data * data) {
auto * sctx = (llama_sampler_top_p *) smpl->ctx; auto * sctx = (llama_sampler_top_p *) smpl->ctx;
struct ggml_tensor * softmax = ggml_soft_max(ctx, ggml_data->logits); struct ggml_tensor * softmax = ggml_soft_max(ctx, data->logits);
ggml_set_name(softmax, "top_p_softmax"); ggml_set_name(softmax, "top_p_softmax");
// Get the sorted indices of the softmax probabilities in descending order. // Get the sorted indices of the softmax probabilities in descending order.
@ -1181,30 +1157,23 @@ static void llama_sampler_top_p_apply_ggml(
struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val); struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val);
ggml_set_name(top_p_bias, "top_p_bias"); ggml_set_name(top_p_bias, "top_p_bias");
ggml_data->logits = ggml_add(ctx, ggml_data->logits, top_p_bias); data->logits = ggml_add(ctx, data->logits, top_p_bias);
ggml_set_name(ggml_data->logits, "top_p_logits"); ggml_set_name(data->logits, "top_p_logits");
ggml_build_forward_expand(gf, ggml_data->logits); ggml_build_forward_expand(gf, data->logits);
}
static void llama_sampler_top_p_init_ggml(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_top_p *) smpl->ctx;
sctx->device = ggml_backend_buft_get_device(buft);
} }
static struct llama_sampler_i llama_sampler_top_p_i = { static struct llama_sampler_i llama_sampler_top_p_i = {
/* .name = */ llama_sampler_top_p_name, /* .name = */ llama_sampler_top_p_name,
/* .accept = */ nullptr, /* .accept = */ nullptr,
/* .apply = */ llama_sampler_top_p_apply, /* .apply = */ llama_sampler_top_p_apply,
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ llama_sampler_top_p_clone, /* .clone = */ llama_sampler_top_p_clone,
/* .free = */ llama_sampler_top_p_free, /* .free = */ llama_sampler_top_p_free,
/* .apply_ggml = */ llama_sampler_top_p_apply_ggml, /* .backend_init = */ llama_sampler_top_p_backend_init,
/* .accept_ggml = */ nullptr, /* .backend_accept = */ nullptr,
/* .set_input_ggml = */ nullptr, /* .backend_apply = */ llama_sampler_top_p_backend_apply,
/* .init_ggml = */ llama_sampler_top_p_init_ggml, /* .backend_set_input = */ nullptr,
}; };
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
@ -1296,17 +1265,24 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
delete (llama_sampler_min_p *) smpl->ctx; delete (llama_sampler_min_p *) smpl->ctx;
} }
static void llama_sampler_min_p_apply_ggml( static void llama_sampler_min_p_backend_init(
struct llama_sampler * smpl, struct llama_sampler * smpl,
struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
struct ggml_cgraph * gf, auto * sctx = (llama_sampler_min_p *) smpl->ctx;
struct llama_sampler_ggml_data * ggml_data) { sctx->device = ggml_backend_buft_get_device(buft);
}
static void llama_sampler_min_p_backend_apply(
struct llama_sampler * smpl,
struct ggml_context * ctx,
struct ggml_cgraph * gf,
struct llama_sampler_backend_data * data) {
auto * sctx = (llama_sampler_min_p *) smpl->ctx; auto * sctx = (llama_sampler_min_p *) smpl->ctx;
struct ggml_tensor * max_idx = ggml_argmax(ctx, ggml_data->logits); struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
ggml_set_name(max_idx, "max_idx"); ggml_set_name(max_idx, "max_idx");
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]); struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
ggml_set_name(logits_rows, "logits_rows"); ggml_set_name(logits_rows, "logits_rows");
struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx); struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx);
@ -1317,7 +1293,7 @@ static void llama_sampler_min_p_apply_ggml(
ggml_set_name(threshold, "min_p_threshold"); ggml_set_name(threshold, "min_p_threshold");
// Subtract the threshold from logits. // Subtract the threshold from logits.
struct ggml_tensor * sub = ggml_sub(ctx, ggml_data->logits, threshold); struct ggml_tensor * sub = ggml_sub(ctx, data->logits, threshold);
// Create a mask where logits below the threshold are 0 (discard), // Create a mask where logits below the threshold are 0 (discard),
// and others are 1 (keep). // and others are 1 (keep).
@ -1333,30 +1309,23 @@ static void llama_sampler_min_p_apply_ggml(
ggml_set_name(min_p_bias, "min_p_bias"); ggml_set_name(min_p_bias, "min_p_bias");
// Add the min_p bias to the logits. // Add the min_p bias to the logits.
ggml_data->logits = ggml_add(ctx, ggml_data->logits, min_p_bias); data->logits = ggml_add(ctx, data->logits, min_p_bias);
ggml_set_name(ggml_data->logits, "min_p_logits"); ggml_set_name(data->logits, "min_p_logits");
ggml_build_forward_expand(gf, ggml_data->logits); ggml_build_forward_expand(gf, data->logits);
}
static void llama_sampler_min_p_init_ggml(
struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_min_p *) smpl->ctx;
sctx->device = ggml_backend_buft_get_device(buft);
} }
static struct llama_sampler_i llama_sampler_min_p_i = { static struct llama_sampler_i llama_sampler_min_p_i = {
/* .name = */ llama_sampler_min_p_name, /* .name = */ llama_sampler_min_p_name,
/* .accept = */ nullptr, /* .accept = */ nullptr,
/* .apply = */ llama_sampler_min_p_apply, /* .apply = */ llama_sampler_min_p_apply,
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ llama_sampler_min_p_clone, /* .clone = */ llama_sampler_min_p_clone,
/* .free = */ llama_sampler_min_p_free, /* .free = */ llama_sampler_min_p_free,
/* .apply_ggml = */ llama_sampler_min_p_apply_ggml, /* .backend_init = */ llama_sampler_min_p_backend_init,
/* .accept_ggml = */ nullptr, /* .backend_accept = */ nullptr,
/* .set_input_ggml = */ nullptr, /* .backend_apply = */ llama_sampler_min_p_backend_apply,
/* .init_ggml = */ llama_sampler_min_p_init_ggml, /* .backend_set_input = */ nullptr,
}; };
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
@ -1451,16 +1420,16 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
} }
static struct llama_sampler_i llama_sampler_typical_i = { static struct llama_sampler_i llama_sampler_typical_i = {
/* .name = */ llama_sampler_typical_name, /* .name = */ llama_sampler_typical_name,
/* .accept = */ nullptr, /* .accept = */ nullptr,
/* .apply = */ llama_sampler_typical_apply, /* .apply = */ llama_sampler_typical_apply,
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ llama_sampler_typical_clone, /* .clone = */ llama_sampler_typical_clone,
/* .free = */ llama_sampler_typical_free, /* .free = */ llama_sampler_typical_free,
/* .apply_ggml = */ nullptr, /* .backend_init = */ nullptr,
/* .accept_ggml = */ nullptr, /* .backend_accept = */ nullptr,
/* .set_input_ggml = */ nullptr, /* .backend_apply = */ nullptr,
/* .init_ggml = */ nullptr, /* .backend_set_input = */ nullptr,
}; };
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
@ -1498,38 +1467,38 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
delete (llama_sampler_temp *) smpl->ctx; delete (llama_sampler_temp *) smpl->ctx;
} }
static void llama_sampler_temp_apply_ggml( static void llama_sampler_temp_backend_apply(
struct llama_sampler * smpl, struct llama_sampler * smpl,
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_cgraph * gf, struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) { struct llama_sampler_backend_data * data) {
auto * ctx_data = (llama_sampler_temp *) smpl->ctx; auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
if (ctx_data->temp <= 0.0f) { if (ctx_data->temp <= 0.0f) {
return; return;
} }
struct ggml_tensor * scaled = ggml_scale(ctx, ggml_data->logits, 1.0f / ctx_data->temp); struct ggml_tensor * scaled = ggml_scale(ctx, data->logits, 1.0f / ctx_data->temp);
ggml_set_name(scaled, "temp_scaled"); ggml_set_name(scaled, "temp_scaled");
// Make sure the scaled tensor is contiguous for subsequent operations // Make sure the scaled tensor is contiguous for subsequent operations
ggml_data->logits = ggml_cont(ctx, scaled); data->logits = ggml_cont(ctx, scaled);
ggml_set_name(ggml_data->logits, "temp_scaled_logits"); ggml_set_name(data->logits, "temp_scaled_logits");
ggml_build_forward_expand(gf, ggml_data->logits); ggml_build_forward_expand(gf, data->logits);
} }
static struct llama_sampler_i llama_sampler_temp_i = { static struct llama_sampler_i llama_sampler_temp_i = {
/* .name = */ llama_sampler_temp_name, /* .name = */ llama_sampler_temp_name,
/* .accept = */ nullptr, /* .accept = */ nullptr,
/* .apply = */ llama_sampler_temp_apply, /* .apply = */ llama_sampler_temp_apply,
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ llama_sampler_temp_clone, /* .clone = */ llama_sampler_temp_clone,
/* .free = */ llama_sampler_temp_free, /* .free = */ llama_sampler_temp_free,
/* .apply_ggml = */ llama_sampler_temp_apply_ggml, /* .backend_init = */ nullptr,
/* .accept_ggml = */ nullptr, /* .backend_accept = */ nullptr,
/* .set_input_ggml = */ nullptr, /* .backend_apply = */ llama_sampler_temp_backend_apply,
/* .init_ggml = */ nullptr, /* .backend_set_input = */ nullptr,
}; };
struct llama_sampler * llama_sampler_init_temp(float temp) { struct llama_sampler * llama_sampler_init_temp(float temp) {
@ -1634,16 +1603,16 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
} }
static struct llama_sampler_i llama_sampler_temp_ext_i = { static struct llama_sampler_i llama_sampler_temp_ext_i = {
/* .name = */ llama_sampler_temp_ext_name, /* .name = */ llama_sampler_temp_ext_name,
/* .accept = */ nullptr, /* .accept = */ nullptr,
/* .apply = */ llama_sampler_temp_ext_apply, /* .apply = */ llama_sampler_temp_ext_apply,
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ llama_sampler_temp_ext_clone, /* .clone = */ llama_sampler_temp_ext_clone,
/* .free = */ llama_sampler_temp_ext_free, /* .free = */ llama_sampler_temp_ext_free,
/* .apply_ggml = */ nullptr, /* .backend_init = */ nullptr,
/* .accept_ggml = */ nullptr, /* .backend_accept = */ nullptr,
/* .set_input_ggml = */ nullptr, /* .backend_apply = */ nullptr,
/* .init_ggml = */ nullptr, /* .backend_set_input = */ nullptr,
}; };
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
@ -1732,16 +1701,16 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
} }
static struct llama_sampler_i llama_sampler_xtc_i = { static struct llama_sampler_i llama_sampler_xtc_i = {
/* .name = */ llama_sampler_xtc_name, /* .name = */ llama_sampler_xtc_name,
/* .accept = */ nullptr, /* .accept = */ nullptr,
/* .apply = */ llama_sample_xtc_apply, /* .apply = */ llama_sample_xtc_apply,
/* .reset = */ llama_sampler_xtc_reset, /* .reset = */ llama_sampler_xtc_reset,
/* .clone = */ llama_sampler_xtc_clone, /* .clone = */ llama_sampler_xtc_clone,
/* .free = */ llama_sampler_xtc_free, /* .free = */ llama_sampler_xtc_free,
/* .apply_ggml = */ nullptr, /* .backend_init = */ nullptr,
/* .accept_ggml = */ nullptr, /* .backend_accept = */ nullptr,
/* .set_input_ggml = */ nullptr, /* .backend_apply = */ nullptr,
/* .init_ggml = */ nullptr, /* .backend_set_input = */ nullptr,
}; };
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
@ -1844,16 +1813,16 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
} }
static struct llama_sampler_i llama_sampler_mirostat_i = { static struct llama_sampler_i llama_sampler_mirostat_i = {
/* .name = */ llama_sampler_mirostat_name, /* .name = */ llama_sampler_mirostat_name,
/* .accept = */ nullptr, /* .accept = */ nullptr,
/* .apply = */ llama_sampler_mirostat_apply, /* .apply = */ llama_sampler_mirostat_apply,
/* .reset = */ llama_sampler_mirostat_reset, /* .reset = */ llama_sampler_mirostat_reset,
/* .clone = */ llama_sampler_mirostat_clone, /* .clone = */ llama_sampler_mirostat_clone,
/* .free = */ llama_sampler_mirostat_free, /* .free = */ llama_sampler_mirostat_free,
/* .apply_ggml = */ nullptr, /* .backend_init = */ nullptr,
/* .accept_ggml = */ nullptr, /* .backend_accept = */ nullptr,
/* .set_input_ggml = */ nullptr, /* .backend_apply = */ nullptr,
/* .init_ggml = */ nullptr, /* .backend_set_input = */ nullptr,
}; };
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
@ -1947,16 +1916,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
} }
static struct llama_sampler_i llama_sampler_mirostat_v2_i = { static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
/* .name = */ llama_sampler_mirostat_v2_name, /* .name = */ llama_sampler_mirostat_v2_name,
/* .accept = */ nullptr, /* .accept = */ nullptr,
/* .apply = */ llama_sampler_mirostat_v2_apply, /* .apply = */ llama_sampler_mirostat_v2_apply,
/* .reset = */ llama_sampler_mirostat_v2_reset, /* .reset = */ llama_sampler_mirostat_v2_reset,
/* .clone = */ llama_sampler_mirostat_v2_clone, /* .clone = */ llama_sampler_mirostat_v2_clone,
/* .free = */ llama_sampler_mirostat_v2_free, /* .free = */ llama_sampler_mirostat_v2_free,
/* .apply_ggml = */ nullptr, /* .backend_init = */ nullptr,
/* .accept_ggml = */ nullptr, /* .backend_accept = */ nullptr,
/* .set_input_ggml = */ nullptr, /* .backend_apply = */ nullptr,
/* .init_ggml = */ nullptr, /* .backend_set_input = */ nullptr,
}; };
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
@ -2068,16 +2037,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
} }
static struct llama_sampler_i llama_sampler_grammar_i = { static struct llama_sampler_i llama_sampler_grammar_i = {
/* .name = */ llama_sampler_grammar_name, /* .name = */ llama_sampler_grammar_name,
/* .accept = */ llama_sampler_grammar_accept_impl, /* .accept = */ llama_sampler_grammar_accept_impl,
/* .apply = */ llama_sampler_grammar_apply, /* .apply = */ llama_sampler_grammar_apply,
/* .reset = */ llama_sampler_grammar_reset, /* .reset = */ llama_sampler_grammar_reset,
/* .clone = */ llama_sampler_grammar_clone, /* .clone = */ llama_sampler_grammar_clone,
/* .free = */ llama_sampler_grammar_free, /* .free = */ llama_sampler_grammar_free,
/* .apply_ggml = */ nullptr, /* .backend_init = */ nullptr,
/* .accept_ggml = */ nullptr, /* .backend_accept = */ nullptr,
/* .set_input_ggml = */ nullptr, /* .backend_apply = */ nullptr,
/* .init_ggml = */ nullptr, /* .backend_set_input = */ nullptr,
}; };
static struct llama_sampler * llama_sampler_init_grammar_impl( static struct llama_sampler * llama_sampler_init_grammar_impl(
@ -2279,16 +2248,16 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
} }
static struct llama_sampler_i llama_sampler_penalties_i = { static struct llama_sampler_i llama_sampler_penalties_i = {
/* .name = */ llama_sampler_penalties_name, /* .name = */ llama_sampler_penalties_name,
/* .accept = */ llama_sampler_penalties_accept, /* .accept = */ llama_sampler_penalties_accept,
/* .apply = */ llama_sampler_penalties_apply, /* .apply = */ llama_sampler_penalties_apply,
/* .reset = */ llama_sampler_penalties_reset, /* .reset = */ llama_sampler_penalties_reset,
/* .clone = */ llama_sampler_penalties_clone, /* .clone = */ llama_sampler_penalties_clone,
/* .free = */ llama_sampler_penalties_free, /* .free = */ llama_sampler_penalties_free,
/* .apply_ggml = */ nullptr, /* .backend_init = */ nullptr,
/* .accept_ggml = */ nullptr, /* .backend_accept = */ nullptr,
/* .set_input_ggml = */ nullptr, /* .backend_apply = */ nullptr,
/* .init_ggml = */ nullptr, /* .backend_set_input = */ nullptr,
}; };
struct llama_sampler * llama_sampler_init_penalties( struct llama_sampler * llama_sampler_init_penalties(
@ -2374,16 +2343,16 @@ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
} }
static struct llama_sampler_i llama_sampler_top_n_sigma_i = { static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
/* .name = */ llama_sampler_top_n_sigma_name, /* .name = */ llama_sampler_top_n_sigma_name,
/* .accept = */ nullptr, /* .accept = */ nullptr,
/* .apply = */ llama_sampler_top_n_sigma_apply, /* .apply = */ llama_sampler_top_n_sigma_apply,
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ llama_sampler_top_n_sigma_clone, /* .clone = */ llama_sampler_top_n_sigma_clone,
/* .free = */ llama_sampler_top_n_sigma_free, /* .free = */ llama_sampler_top_n_sigma_free,
/* .apply_ggml = */ nullptr, /* .backend_init = */ nullptr,
/* .accept_ggml = */ nullptr, /* .backend_accept = */ nullptr,
/* .set_input_ggml = */ nullptr, /* .backend_apply = */ nullptr,
/* .init_ggml = */ nullptr, /* .backend_set_input = */ nullptr,
}; };
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) { struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
@ -2708,16 +2677,16 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) {
} }
static struct llama_sampler_i llama_sampler_dry_i = { static struct llama_sampler_i llama_sampler_dry_i = {
/* .name = */ llama_sampler_dry_name, /* .name = */ llama_sampler_dry_name,
/* .accept = */ llama_sampler_dry_accept, /* .accept = */ llama_sampler_dry_accept,
/* .apply = */ llama_sampler_dry_apply, /* .apply = */ llama_sampler_dry_apply,
/* .reset = */ llama_sampler_dry_reset, /* .reset = */ llama_sampler_dry_reset,
/* .clone = */ llama_sampler_dry_clone, /* .clone = */ llama_sampler_dry_clone,
/* .free = */ llama_sampler_dry_free, /* .free = */ llama_sampler_dry_free,
/* .apply_ggml = */ nullptr, /* .backend_init = */ nullptr,
/* .accept_ggml = */ nullptr, /* .backend_accept = */ nullptr,
/* .set_input_ggml = */ nullptr, /* .backend_apply = */ nullptr,
/* .init_ggml = */ nullptr, /* .backend_set_input = */ nullptr,
}; };
struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
@ -2857,11 +2826,11 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
delete (llama_sampler_logit_bias *) smpl->ctx; delete (llama_sampler_logit_bias *) smpl->ctx;
} }
static void llama_sampler_logit_bias_apply_ggml( static void llama_sampler_logit_bias_backend_apply(
struct llama_sampler * smpl, struct llama_sampler * smpl,
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_cgraph * gf, struct ggml_cgraph * gf,
struct llama_sampler_ggml_data * ggml_data) { struct llama_sampler_backend_data * data) {
GGML_UNUSED(gf); GGML_UNUSED(gf);
GGML_UNUSED(ctx); GGML_UNUSED(ctx);
@ -2871,11 +2840,11 @@ static void llama_sampler_logit_bias_apply_ggml(
} }
// Add the sparse logit logit_bias to the logits // Add the sparse logit logit_bias to the logits
struct ggml_tensor * logit_biased = ggml_add_inplace(ctx, ggml_data->logits, sctx->inp_logit_bias); struct ggml_tensor * logit_biased = ggml_add_inplace(ctx, data->logits, sctx->inp_logit_bias);
ggml_build_forward_expand(gf, logit_biased); ggml_build_forward_expand(gf, logit_biased);
} }
static void llama_sampler_logit_bias_set_input_ggml(struct llama_sampler * smpl) { static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) {
auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
if (sctx->logit_bias.empty()) { if (sctx->logit_bias.empty()) {
return; return;
@ -2892,7 +2861,7 @@ static void llama_sampler_logit_bias_set_input_ggml(struct llama_sampler * smpl)
ggml_backend_tensor_set(sctx->inp_logit_bias, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->inp_logit_bias)); ggml_backend_tensor_set(sctx->inp_logit_bias, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
} }
static void llama_sampler_logit_bias_init_ggml( static void llama_sampler_logit_bias_backend_init(
struct llama_sampler * smpl, struct llama_sampler * smpl,
ggml_backend_buffer_type_t buft) { ggml_backend_buffer_type_t buft) {
auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
@ -2918,16 +2887,16 @@ static void llama_sampler_logit_bias_init_ggml(
} }
static struct llama_sampler_i llama_sampler_logit_bias_i = { static struct llama_sampler_i llama_sampler_logit_bias_i = {
/* .name = */ llama_sampler_logit_bias_name, /* .name = */ llama_sampler_logit_bias_name,
/* .accept = */ nullptr, /* .accept = */ nullptr,
/* .apply = */ llama_sampler_logit_bias_apply, /* .apply = */ llama_sampler_logit_bias_apply,
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ llama_sampler_logit_bias_clone, /* .clone = */ llama_sampler_logit_bias_clone,
/* .free = */ llama_sampler_logit_bias_free, /* .free = */ llama_sampler_logit_bias_free,
/* .apply_ggml = */ llama_sampler_logit_bias_apply_ggml, /* .backend_init = */ llama_sampler_logit_bias_backend_init,
/* .accept_ggml = */ nullptr, /* .backend_accept = */ nullptr,
/* .set_input_ggml = */ llama_sampler_logit_bias_set_input_ggml, /* .backend_apply = */ llama_sampler_logit_bias_backend_apply,
/* .init_ggml = */ llama_sampler_logit_bias_init_ggml, /* .backend_set_input = */ llama_sampler_logit_bias_backend_set_input,
}; };
struct llama_sampler * llama_sampler_init_logit_bias( struct llama_sampler * llama_sampler_init_logit_bias(
@ -3155,16 +3124,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) {
} }
static struct llama_sampler_i llama_sampler_infill_i = { static struct llama_sampler_i llama_sampler_infill_i = {
/* .name = */ llama_sampler_infill_name, /* .name = */ llama_sampler_infill_name,
/* .accept = */ nullptr, /* .accept = */ nullptr,
/* .apply = */ llama_sampler_infill_apply, /* .apply = */ llama_sampler_infill_apply,
/* .reset = */ nullptr, /* .reset = */ nullptr,
/* .clone = */ llama_sampler_infill_clone, /* .clone = */ llama_sampler_infill_clone,
/* .free = */ llama_sampler_infill_free, /* .free = */ llama_sampler_infill_free,
/* .apply_ggml = */ nullptr, /* .backend_apply = */ nullptr,
/* .accept_ggml = */ nullptr, /* .backend_accept = */ nullptr,
/* .set_input_ggml = */ nullptr, /* .backend_set_input = */ nullptr,
/* .init_ggml = */ nullptr, /* .backend_init = */ nullptr,
}; };
struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) { struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {

View File

@ -24,9 +24,9 @@ struct llama_sampler_chain {
}; };
struct llama_sampler * llama_sampler_init_dry_testing( struct llama_sampler * llama_sampler_init_dry_testing(
int32_t context_size, int32_t context_size,
float dry_multiplier, float dry_multiplier,
float dry_base, float dry_base,
int32_t dry_allowed_length, int32_t dry_allowed_length,
int32_t dry_penalty_last_n, int32_t dry_penalty_last_n,
const std::vector<std::vector<llama_token>>& seq_breakers); const std::vector<std::vector<llama_token>> & seq_breakers);

View File

@ -345,7 +345,7 @@ static void test_backend_top_k_sampling(const char * model_path) {
// sampling, first top_k on the backend and then dist on the CPU. // sampling, first top_k on the backend and then dist on the CPU.
struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
struct llama_sampler * chain = llama_sampler_chain_init(chain_params); struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
GGML_ASSERT(chain->iface->apply_ggml != nullptr); GGML_ASSERT(chain->iface->backend_apply != nullptr);
llama_sampler_chain_add(chain, llama_sampler_init_dist(18)); llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx); llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);