llama : naming

Georgi Gerganov 2025-11-30 00:05:47 +02:00
parent 1760bd69b3
commit c187003d81
8 changed files with 136 additions and 136 deletions

View File

@@ -123,9 +123,9 @@ struct common_sampler {
     }
     void set_logits(struct llama_context * ctx, int idx) {
-        const float * sampled_probs = llama_get_backend_sampled_probs_ith (ctx, idx);
-        const float * sampled_logits = llama_get_backend_sampled_logits_ith (ctx, idx);
-        const llama_token * sampled_ids = llama_get_backend_sampled_candidates_ith(ctx, idx);
+        const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx);
+        const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx);
+        const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);
         const llama_model * model = llama_get_model(ctx);
         const llama_vocab * vocab = llama_model_get_vocab(model);
@@ -133,13 +133,13 @@ struct common_sampler {
         const int n_vocab = llama_vocab_n_tokens(vocab);
         if (sampled_probs) {
-            const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx);
+            const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
             cur.resize(sampled_probs_count);
             for (uint32_t i = 0; i < sampled_probs_count; ++i) {
                 cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
             }
         } else if (sampled_logits) {
-            const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
+            const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
             cur.resize(sampled_logits_count);
             for (uint32_t i = 0; i < sampled_logits_count; i++) {
                 cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
@@ -536,7 +536,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     // Check if a backend sampler has already sampled a token in which case we
     // return that token id directly.
     {
-        const llama_token id = llama_get_backend_sampled_token_ith(ctx, idx);
+        const llama_token id = llama_get_sampled_token_ith(ctx, idx);
         if (id != LLAMA_TOKEN_NULL) {
             LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
             return id;

View File

@@ -979,29 +979,29 @@ extern "C" {
     // Get the backend sampled token for the ith token.
     // Returns LLAMA_TOKEN_NULL if no token was sampled.
-    LLAMA_API llama_token llama_get_backend_sampled_token_ith(struct llama_context * ctx, int32_t i);
+    LLAMA_API llama_token llama_get_sampled_token_ith(struct llama_context * ctx, int32_t i);
     // Get the backend sampled probabilites for the ith token
-    // The index matches llama_get_backend_sampled_token_ith().
+    // The index matches llama_get_sampled_token_ith().
     // Returns NULL if no probabilites were generated.
-    LLAMA_API float * llama_get_backend_sampled_probs_ith(struct llama_context * ctx, int32_t i);
+    LLAMA_API float * llama_get_sampled_probs_ith(struct llama_context * ctx, int32_t i);
     //
     // Get the number of backend sampled probabilites for the ith token.
-    LLAMA_API uint32_t llama_get_backend_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);
+    LLAMA_API uint32_t llama_get_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);
     // Get the backend sampled logits for the ith token
     // Returns NULL if no logits were sampled.
-    LLAMA_API float * llama_get_backend_sampled_logits_ith(struct llama_context * ctx, int32_t i);
+    LLAMA_API float * llama_get_sampled_logits_ith(struct llama_context * ctx, int32_t i);
     //
     // Get the number of backend sampled logits for the ith token.
-    LLAMA_API uint32_t llama_get_backend_sampled_logits_count_ith(struct llama_context * ctx, int32_t i);
+    LLAMA_API uint32_t llama_get_sampled_logits_count_ith(struct llama_context * ctx, int32_t i);
     // Get the backend sampled candidates (token ids) for the ith token
     // Returns NULL if no candidates were sampled.
-    LLAMA_API llama_token * llama_get_backend_sampled_candidates_ith(struct llama_context * ctx, int32_t i);
+    LLAMA_API llama_token * llama_get_sampled_candidates_ith(struct llama_context * ctx, int32_t i);
     //
     // Get the number of backend sampled candidates for the ith token.
-    LLAMA_API uint32_t llama_get_backend_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i);
+    LLAMA_API uint32_t llama_get_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i);
     //
     // Vocab
@@ -1177,7 +1177,7 @@ extern "C" {
     typedef void * llama_sampler_context_t;
-    struct llama_sampler_backend_data {
+    struct llama_sampler_data {
         struct ggml_tensor * logits;
         struct ggml_tensor * probs;
         struct ggml_tensor * sampled;
@@ -1203,10 +1203,10 @@ extern "C" {
                 struct ggml_tensor * selected_token);
         void (*backend_apply)(
                 struct llama_sampler * smpl,
                 struct ggml_context * ctx,
                 struct ggml_cgraph * gf,
-                struct llama_sampler_backend_data * ggml_data);
+                struct llama_sampler_data * data);
         void (*backend_set_input)(struct llama_sampler * smpl);
     };
@@ -1217,7 +1217,7 @@ extern "C" {
         llama_sampler_context_t ctx;
     };
-    LLAMA_API bool llama_set_backend_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
+    LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
     // mirror of llama_sampler_i:
     LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
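For orientation, here is a minimal sketch of how the renamed getters are typically consumed together from application code. It mirrors the set_logits() logic in the first file of this diff; the helper name read_sampled_output is hypothetical, and nothing beyond the llama.h calls shown above is assumed.

#include "llama.h"

// Sketch: consume the output of a backend sampler for output index `idx`.
static llama_token read_sampled_output(llama_context * ctx, int32_t idx) {
    const llama_token id = llama_get_sampled_token_ith(ctx, idx);
    if (id != LLAMA_TOKEN_NULL) {
        // the backend sampler already selected a token - nothing left to do
        return id;
    }
    if (const float * probs = llama_get_sampled_probs_ith(ctx, idx)) {
        // probs[i] belongs to candidate ids[i]; finish sampling on the CPU
        const uint32_t      n   = llama_get_sampled_probs_count_ith(ctx, idx);
        const llama_token * ids = llama_get_sampled_candidates_ith(ctx, idx);
        (void) probs; (void) n; (void) ids;
    } else if (const float * logits = llama_get_sampled_logits_ith(ctx, idx)) {
        // only a (possibly truncated) set of logits was produced;
        // run the remaining CPU samplers over (ids, logits)
        const uint32_t      n   = llama_get_sampled_logits_count_ith(ctx, idx);
        const llama_token * ids = llama_get_sampled_candidates_ith(ctx, idx);
        (void) logits; (void) n; (void) ids;
    }
    return LLAMA_TOKEN_NULL;
}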

View File

@@ -68,7 +68,7 @@ llama_context::llama_context(
     for (size_t i = 0; i < params.n_samplers; ++i) {
         const auto & config = params.samplers[i];
-        if (set_backend_sampler(config.seq_id, config.sampler)) {
+        if (set_sampler(config.seq_id, config.sampler)) {
             const int n_samplers = llama_sampler_chain_n(config.sampler);
             LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers);
@@ -670,7 +670,7 @@ float * llama_context::get_embeddings() {
     return embd;
 }
-llama_token * llama_context::get_backend_sampled_tokens() {
+llama_token * llama_context::get_sampled_tokens() {
     return sampling.sampled;
 }
@@ -723,7 +723,7 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
     return it->second.data();
 }
-llama_token llama_context::get_backend_sampled_token_ith(int32_t idx) {
+llama_token llama_context::get_sampled_token_ith(int32_t idx) {
     output_reorder();
     if (sampling.sampled == nullptr) {
@@ -740,7 +740,7 @@ llama_token llama_context::get_backend_sampled_token_ith(int32_t idx) {
     }
 }
-float * llama_context::get_backend_sampled_probs_ith(int32_t idx) {
+float * llama_context::get_sampled_probs_ith(int32_t idx) {
     output_reorder();
     if (sampling.probs == nullptr) {
@@ -759,7 +759,7 @@ float * llama_context::get_backend_sampled_probs_ith(int32_t idx) {
     }
 }
-float * llama_context::get_backend_sampled_logits_ith(int32_t idx) {
+float * llama_context::get_sampled_logits_ith(int32_t idx) {
     output_reorder();
     if (sampling.logits == nullptr) {
@@ -778,7 +778,7 @@ float * llama_context::get_backend_sampled_logits_ith(int32_t idx) {
     }
 }
-const llama_token * llama_context::get_backend_sampled_candidates_ith(int32_t idx) {
+const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) {
     output_reorder();
     try {
@@ -795,7 +795,7 @@ const llama_token * llama_context::get_backend_sampled_candidates_ith(int32_t id
     return sampling.token_ids_full_vocab.data();
 }
-size_t llama_context::get_backend_sampled_candidates_count(int32_t idx) {
+size_t llama_context::get_sampled_candidates_count(int32_t idx) {
     output_reorder();
     if (sampling.candidates == nullptr) {
@@ -814,7 +814,7 @@ size_t llama_context::get_backend_sampled_candidates_count(int32_t idx) {
     }
 }
-size_t llama_context::get_backend_sampled_logits_count(int32_t idx) {
+size_t llama_context::get_sampled_logits_count(int32_t idx) {
     output_reorder();
     if (sampling.logits == nullptr) {
@@ -833,7 +833,7 @@ size_t llama_context::get_backend_sampled_logits_count(int32_t idx) {
     }
 }
-size_t llama_context::get_backend_sampled_probs_count(int32_t idx) {
+size_t llama_context::get_sampled_probs_count(int32_t idx) {
     output_reorder();
     if (sampling.probs == nullptr) {
@@ -909,7 +909,7 @@ void llama_context::set_warmup(bool value) {
     cparams.warmup = value;
 }
-bool llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
+bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
     LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
     const bool can_offload =
@@ -2978,10 +2978,10 @@ float * llama_get_logits(llama_context * ctx) {
 float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
     ctx->synchronize();
-    if (ctx->get_backend_sampled_token_ith(i) != LLAMA_TOKEN_NULL) {
+    if (ctx->get_sampled_token_ith(i) != LLAMA_TOKEN_NULL) {
         return nullptr;
     }
-    if (ctx->get_backend_sampled_probs_ith(i) != nullptr) {
+    if (ctx->get_sampled_probs_ith(i) != nullptr) {
         return nullptr;
     }
@@ -3006,50 +3006,50 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
     return ctx->get_embeddings_seq(seq_id);
 }
-bool llama_set_backend_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
-    return ctx->set_backend_sampler(seq_id, smpl);
+bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) {
+    return ctx->set_sampler(seq_id, smpl);
 }
-llama_token llama_get_backend_sampled_token_ith(llama_context * ctx, int32_t i) {
+llama_token llama_get_sampled_token_ith(llama_context * ctx, int32_t i) {
     ctx->synchronize();
-    return ctx->get_backend_sampled_token_ith(i);
+    return ctx->get_sampled_token_ith(i);
 }
-float * llama_get_backend_sampled_probs_ith(llama_context * ctx, int32_t i) {
+float * llama_get_sampled_probs_ith(llama_context * ctx, int32_t i) {
     ctx->synchronize();
-    return ctx->get_backend_sampled_probs_ith(i);
+    return ctx->get_sampled_probs_ith(i);
 }
-float * llama_get_backend_sampled_logits_ith(llama_context * ctx, int32_t i) {
+float * llama_get_sampled_logits_ith(llama_context * ctx, int32_t i) {
     ctx->synchronize();
-    return ctx->get_backend_sampled_logits_ith(i);
+    return ctx->get_sampled_logits_ith(i);
 }
-llama_token * llama_get_backend_sampled_candidates_ith(llama_context * ctx, int32_t i) {
+llama_token * llama_get_sampled_candidates_ith(llama_context * ctx, int32_t i) {
     ctx->synchronize();
-    return const_cast<llama_token *>(ctx->get_backend_sampled_candidates_ith(i));
+    return const_cast<llama_token *>(ctx->get_sampled_candidates_ith(i));
 }
-uint32_t llama_get_backend_sampled_candidates_count_ith(llama_context * ctx, int32_t i) {
+uint32_t llama_get_sampled_candidates_count_ith(llama_context * ctx, int32_t i) {
     ctx->synchronize();
-    return static_cast<uint32_t>(ctx->get_backend_sampled_candidates_count(i));
+    return static_cast<uint32_t>(ctx->get_sampled_candidates_count(i));
 }
-uint32_t llama_get_backend_sampled_logits_count_ith(llama_context * ctx, int32_t i) {
+uint32_t llama_get_sampled_logits_count_ith(llama_context * ctx, int32_t i) {
     ctx->synchronize();
-    return static_cast<uint32_t>(ctx->get_backend_sampled_logits_count(i));
+    return static_cast<uint32_t>(ctx->get_sampled_logits_count(i));
 }
-uint32_t llama_get_backend_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
+uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
     ctx->synchronize();
-    return static_cast<uint32_t>(ctx->get_backend_sampled_probs_count(i));
+    return static_cast<uint32_t>(ctx->get_sampled_probs_count(i));
 }
 // llama adapter API

View File

@@ -66,17 +66,17 @@ struct llama_context {
     float * get_embeddings_ith(int32_t i);
     float * get_embeddings_seq(llama_seq_id seq_id);
-    llama_token * get_backend_sampled_tokens();
-    llama_token get_backend_sampled_token_ith(int32_t idx);
-    float * get_backend_sampled_logits_ith(int32_t idx);
-    size_t get_backend_sampled_logits_count(int32_t idx);
-    float * get_backend_sampled_probs_ith(int32_t idx);
-    size_t get_backend_sampled_probs_count(int32_t idx);
-    const llama_token * get_backend_sampled_candidates_ith(int32_t idx);
-    size_t get_backend_sampled_candidates_count(int32_t idx);
+    llama_token * get_sampled_tokens();
+    llama_token get_sampled_token_ith(int32_t idx);
+    float * get_sampled_logits_ith(int32_t idx);
+    size_t get_sampled_logits_count(int32_t idx);
+    float * get_sampled_probs_ith(int32_t idx);
+    size_t get_sampled_probs_count(int32_t idx);
+    const llama_token * get_sampled_candidates_ith(int32_t idx);
+    size_t get_sampled_candidates_count(int32_t idx);
     void attach_threadpool(
             ggml_threadpool_t threadpool,
@@ -221,7 +221,7 @@ public:
     // reserve a graph with a dummy ubatch of the specified size
     ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
-    bool set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler);
+    bool set_sampler(llama_seq_id seq_id, llama_sampler * sampler);
 private:
     llm_graph_params graph_params(

View File

@@ -2090,7 +2090,7 @@ void llm_graph_context::build_sampling() const {
         ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, n_vocab, row_idx * logits_t->nb[1]);
         ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
-        struct llama_sampler_backend_data data = {
+        struct llama_sampler_data data = {
             /*.logits =*/ logits_seq,
             /*.probs =*/ nullptr,
             /*.sampled =*/ nullptr,

View File

@@ -410,10 +410,10 @@ void llama_sampler_free(struct llama_sampler * smpl) {
 }
 llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
-    const llama_token sampled_token = llama_get_backend_sampled_token_ith (ctx, idx);
-    const float * sampled_probs = llama_get_backend_sampled_probs_ith (ctx, idx);
-    const float * sampled_logits = llama_get_backend_sampled_logits_ith (ctx, idx);
-    const llama_token * sampled_ids = llama_get_backend_sampled_candidates_ith(ctx, idx);
+    const llama_token sampled_token = llama_get_sampled_token_ith (ctx, idx);
+    const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx);
+    const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx);
+    const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);
     // If a backend sampler has already sampled a token, return it.
     if (sampled_token != LLAMA_TOKEN_NULL) {
@@ -430,13 +430,13 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
     std::vector<llama_token_data> cur;
     if (sampled_probs) {
-        const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx);
+        const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
         cur.resize(sampled_probs_count);
         for (uint32_t i = 0; i < sampled_probs_count; ++i) {
             cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
         }
     } else if (sampled_logits) {
-        const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
+        const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
         cur.resize(sampled_logits_count);
         for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
             cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
@@ -557,10 +557,10 @@ static void llama_sampler_chain_backend_accept(
 }
 static void llama_sampler_chain_backend_apply(
         struct llama_sampler * smpl,
         struct ggml_context * ctx,
         struct ggml_cgraph * gf,
-        struct llama_sampler_backend_data * data) {
+        struct llama_sampler_data * data) {
     auto * chain = (llama_sampler_chain *) smpl->ctx;
     for (auto * smpl : chain->samplers) {
@@ -659,10 +659,10 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
 }
 static void llama_sampler_greedy_backend_apply(
         struct llama_sampler * smpl,
         struct ggml_context * ctx,
         struct ggml_cgraph * gf,
-        struct llama_sampler_backend_data * data) {
+        struct llama_sampler_data * data) {
     GGML_UNUSED(gf);
     GGML_UNUSED(smpl);
     struct ggml_tensor * argmax_result = ggml_argmax(ctx, data->logits);
@@ -817,10 +817,10 @@ static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
 }
 static void llama_sampler_dist_backend_apply(
         struct llama_sampler * smpl,
         struct ggml_context * ctx,
         struct ggml_cgraph * gf,
-        struct llama_sampler_backend_data * data) {
+        struct llama_sampler_data * data) {
     GGML_UNUSED(gf);
     auto * sctx = (llama_sampler_dist *) smpl->ctx;
@@ -875,8 +875,8 @@ static void llama_sampler_dist_backend_apply(
 }
 static void llama_sampler_dist_backend_init(
         struct llama_sampler * smpl,
         ggml_backend_buffer_type_t buft) {
     auto * sctx = (llama_sampler_dist *) smpl->ctx;
     sctx->device = ggml_backend_buft_get_device(buft);
@@ -963,10 +963,10 @@ static void llama_sampler_top_k_backend_init(
 }
 static void llama_sampler_top_k_backend_apply(
         struct llama_sampler * smpl,
         struct ggml_context * ctx,
         struct ggml_cgraph * gf,
-        struct llama_sampler_backend_data * data) {
+        struct llama_sampler_data * data) {
     auto * ctx_data = (llama_sampler_top_k *) smpl->ctx;
@@ -1101,10 +1101,10 @@ static void llama_sampler_top_p_backend_init(
 }
 static void llama_sampler_top_p_backend_apply(
         struct llama_sampler * smpl,
         struct ggml_context * ctx,
         struct ggml_cgraph * gf,
-        struct llama_sampler_backend_data * data) {
+        struct llama_sampler_data * data) {
     auto * sctx = (llama_sampler_top_p *) smpl->ctx;
     struct ggml_tensor * softmax = ggml_soft_max(ctx, data->logits);
@@ -1273,10 +1273,10 @@ static void llama_sampler_min_p_backend_init(
 }
 static void llama_sampler_min_p_backend_apply(
         struct llama_sampler * smpl,
         struct ggml_context * ctx,
         struct ggml_cgraph * gf,
-        struct llama_sampler_backend_data * data) {
+        struct llama_sampler_data * data) {
     auto * sctx = (llama_sampler_min_p *) smpl->ctx;
     struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits);
@@ -1468,10 +1468,10 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
 }
 static void llama_sampler_temp_backend_apply(
         struct llama_sampler * smpl,
         struct ggml_context * ctx,
         struct ggml_cgraph * gf,
-        struct llama_sampler_backend_data * data) {
+        struct llama_sampler_data * data) {
     auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
     if (ctx_data->temp <= 0.0f) {
@@ -2827,10 +2827,10 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
 }
 static void llama_sampler_logit_bias_backend_apply(
         struct llama_sampler * smpl,
         struct ggml_context * ctx,
         struct ggml_cgraph * gf,
-        struct llama_sampler_backend_data * data) {
+        struct llama_sampler_data * data) {
     GGML_UNUSED(gf);
     GGML_UNUSED(ctx);
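As context for the struct rename, here is a hedged sketch of what a custom sampler's backend_apply callback might look like under the new llama_sampler_data name, patterned on llama_sampler_greedy_backend_apply above. The function name my_sampler_backend_apply is hypothetical, and writing the result back through data->sampled is an assumption based on the struct fields shown in llama.h, not something this diff confirms.

// Sketch only: a greedy-style backend_apply using the renamed llama_sampler_data.
static void my_sampler_backend_apply(
        struct llama_sampler * smpl,
        struct ggml_context * ctx,
        struct ggml_cgraph * gf,
        struct llama_sampler_data * data) {
    GGML_UNUSED(smpl);
    GGML_UNUSED(gf);
    // pick the index of the highest logit, as the greedy backend sampler above does;
    // storing it in data->sampled is assumed from the fields of llama_sampler_data
    data->sampled = ggml_argmax(ctx, data->logits);
}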

View File

@@ -290,17 +290,17 @@ static void test_backend_greedy_sampling(const char * model_path) {
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-    llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
     printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
-    token = llama_get_backend_sampled_token_ith(test_ctx.ctx, -1);
+    token = llama_get_sampled_token_ith(test_ctx.ctx, -1);
     printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
     for (int i = 0; i < 10; i++) {
         int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
-        llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, loop_idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, loop_idx);
         printf("Generation step %d: token id:%d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
         if (!test_ctx.decode_token(token, 0)) {
             GGML_ASSERT(false && "Failed to decode token");
@@ -328,14 +328,14 @@ static void test_backend_top_k_sampling(const char * model_path) {
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-    float * logits = llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx);
-    uint32_t n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
+    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
     for (size_t i = 0; i < n_logits; ++i) {
         printf("top_k logit[%zu] = %.6f\n", i, logits[i]);
     }
-    llama_token * candidates = llama_get_backend_sampled_candidates_ith(test_ctx.ctx, batch_idx);
-    uint32_t n_candidates = llama_get_backend_sampled_candidates_count_ith(test_ctx.ctx, batch_idx);
+    llama_token * candidates = llama_get_sampled_candidates_ith(test_ctx.ctx, batch_idx);
+    uint32_t n_candidates = llama_get_sampled_candidates_count_ith(test_ctx.ctx, batch_idx);
     for (size_t i = 0; i < n_candidates; ++i) {
         printf("top_k candidate[%zu] = %d : %s\n", i, candidates[i],
             test_ctx.token_to_piece(candidates[i], false).c_str());
@@ -386,7 +386,7 @@ static void test_backend_temp_sampling(const char * model_path) {
     // Verfify sequence 0
     {
         int32_t batch_idx = test_ctx.idx_for_seq(0);
-        int n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
         GGML_ASSERT(n_logits == test_ctx.n_vocab);
         // Sample from sequence 0 using CPU sampler
@@ -443,8 +443,8 @@ static void test_backend_min_p_sampling(const char * model_path) {
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-    float * logits = llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx);
-    uint32_t n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
+    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
     // Print the logits that are above the min-p threshold
     std::vector<float> filtered_logits;
@@ -501,8 +501,8 @@ static void test_backend_top_p_sampling(const char * model_path) {
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-    float * logits = llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx);
-    uint32_t n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+    float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
+    uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
     // Print the logits that are above the min-p threshold
     std::vector<float> filtered_logits;
@@ -569,7 +569,7 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
     // Verfiy sequence 0
     {
         int32_t batch_idx = test_ctx.idx_for_seq(0);
-        llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("Seq 0 sampled token id=%d, string='%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -578,7 +578,7 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
     // Verify sequence 1
     {
         int32_t batch_idx= test_ctx.idx_for_seq(1);
-        llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("Seq 1 sampled token id=%d, string='%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -591,7 +591,7 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
     for (llama_seq_id seq_id : {0, 1}) {
         int32_t idx = test_ctx.idx_for_seq(seq_id);
-        llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf(" Seq %d, step %d: token id=%d, string='%s'\n", seq_id, step, token, token_str.c_str());
         tokens[seq_id] = token;
@@ -625,12 +625,12 @@ static void test_backend_dist_sampling(const char * model_path) {
     }
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
-    llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
     printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
-    GGML_ASSERT(llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
-    token = llama_get_backend_sampled_token_ith(test_ctx.ctx, -1);
+    GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
+    token = llama_get_sampled_token_ith(test_ctx.ctx, -1);
     printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
 }
@@ -660,7 +660,7 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
     struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
     llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
-    llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
     llama_token cpu_token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
     printf("dist & cpu sampled id:%d, string:'%s'\n", cpu_token, test_ctx.token_to_piece(cpu_token, false).c_str());
     GGML_ASSERT(backend_token == cpu_token);
@@ -707,7 +707,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
         GGML_ASSERT(false && "Failed to decode token");
     }
-    llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id));
+    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id));
     const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false);
     printf("logit bias sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());
     GGML_ASSERT(backend_token == bias_token);
@@ -748,22 +748,22 @@ static void test_backend_mixed_sampling(const char * model_path) {
     // Verfiy sequence 0 that used the dist backend sampler.
     {
         int32_t batch_idx = test_ctx.idx_for_seq(0);
-        llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("sampled token id=%d, string='%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
-        GGML_ASSERT(llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
-        GGML_ASSERT(llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx) == 0);
+        GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
+        GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx) == 0);
     }
     // Verfiy sequence 1 that used the top-k backend sampler.
     {
         int32_t batch_idx = test_ctx.idx_for_seq(1);
-        float * logits = llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx);
+        float * logits = llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx);
         GGML_ASSERT(logits != nullptr);
-        size_t n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
+        size_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
         GGML_ASSERT(n_logits == (size_t) k);
-        GGML_ASSERT(llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx) == LLAMA_TOKEN_NULL);
+        GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx, batch_idx) == LLAMA_TOKEN_NULL);
     }
     printf("backend mixed sampling test PASSED\n");
@@ -790,12 +790,12 @@ static void test_backend_set_sampler(const char * model_path) {
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
     // Sample using backend sampler configured above
-    llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+    llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
     const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false);
     printf("dist sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());
     // Now clear the backend sampler for this sequence.
-    llama_set_backend_sampler(test_ctx.ctx, seq_id, nullptr);
+    llama_set_sampler(test_ctx.ctx, seq_id, nullptr);
     printf("Cleared backend sampler for seq_id %d\n", seq_id);
     // Sample using CPU sampler
@@ -810,8 +810,8 @@ static void test_backend_set_sampler(const char * model_path) {
     // Should not have any sampled token or probs after clearing the backend sampler.
     const int32_t idx = test_ctx.idx_for_seq(seq_id);
-    GGML_ASSERT(llama_get_backend_sampled_token_ith(test_ctx.ctx, idx) == LLAMA_TOKEN_NULL);
-    GGML_ASSERT(llama_get_backend_sampled_probs_ith(test_ctx.ctx, idx) == nullptr);
+    GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx, idx) == LLAMA_TOKEN_NULL);
+    GGML_ASSERT(llama_get_sampled_probs_ith(test_ctx.ctx, idx) == nullptr);
     // Sample the token using the CPU sampler chain.
     llama_token token2 = llama_sampler_sample(chain, test_ctx.ctx, seq_id);
@@ -824,13 +824,13 @@ static void test_backend_set_sampler(const char * model_path) {
     struct llama_sampler * new_backend_sampler_chain = llama_sampler_chain_init(new_backend_chain_params);
     llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_init_top_k(20));
     llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_init_dist(seed));
-    llama_set_backend_sampler(test_ctx.ctx, seq_id, new_backend_sampler_chain);
+    llama_set_sampler(test_ctx.ctx, seq_id, new_backend_sampler_chain);
     if (!test_ctx.decode_tokens(tokens2)) {
         GGML_ASSERT(false && "Failed to decode token");
     }
-    llama_token new_backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id));
+    llama_token new_backend_token = llama_get_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id));
     const std::string new_backend_token_str = test_ctx.token_to_piece(new_backend_token, false);
     printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str());
 }
@@ -864,7 +864,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
     // Verify sequence 0 (backend sampled)
     {
         int32_t batch_idx = test_ctx.idx_for_seq(0);
-        llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+        llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
         const std::string token_str = test_ctx.token_to_piece(token, false);
         printf("Seq 0 (backend) sampled token id=%d, string='%s'\n", token, token_str.c_str());
         GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
@@ -874,7 +874,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
     {
         int32_t batch_idx = test_ctx.idx_for_seq(1);
-        llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+        llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
         GGML_ASSERT(backend_token == LLAMA_TOKEN_NULL);
         struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
@@ -892,7 +892,7 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
     {
         // clear the backend sampler for seq 0 so that there are no backend
         // samplers.
-        llama_set_backend_sampler(test_ctx.ctx, 0, nullptr);
+        llama_set_sampler(test_ctx.ctx, 0, nullptr);
         // Create a CPU sampler and verify we can sampler from it.
         struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
@@ -914,14 +914,14 @@ static void test_backend_cpu_mixed_batch(const char * model_path) {
     struct llama_sampler * sampler_chain= llama_sampler_chain_init(chain_params);
     llama_sampler_chain_add(sampler_chain, llama_sampler_init_dist(88));
-    llama_set_backend_sampler(test_ctx.ctx, 0, sampler_chain);
+    llama_set_sampler(test_ctx.ctx, 0, sampler_chain);
     if (!test_ctx.decode_token(3834, 0)) {
         GGML_ASSERT(false && "Failed to decode token");
     }
     int32_t batch_idx = test_ctx.idx_for_seq(0);
-    llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
     const std::string token_str = test_ctx.token_to_piece(token, false);
     printf("re-added backend sampled token id=%d, string='%s'\n", token, token_str.c_str());
     GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);

View File

@@ -1014,7 +1014,7 @@ struct server_context_impl {
         SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
         llama_sampler * backend_chain = common_sampler_chain_backend(slot.smpl.get());
-        llama_set_backend_sampler(ctx, slot.id, backend_chain);
+        llama_set_sampler(ctx, slot.id, backend_chain);
     }
     // initialize draft batch