diff --git a/include/llama.h b/include/llama.h index 263733cf2c..24cd5be4a5 100644 --- a/include/llama.h +++ b/include/llama.h @@ -210,13 +210,6 @@ extern "C" { bool sorted; // note: do not assume the data is sorted - always check this flag } llama_token_data_array; - struct llama_sampler_ggml_data { - struct ggml_tensor * logits; - struct ggml_tensor * probs; - struct ggml_tensor * sampled; - struct ggml_tensor * candidates; - }; - typedef bool (*llama_progress_callback)(float progress, void * user_data); // Input data for llama_encode/llama_decode @@ -1181,11 +1174,16 @@ extern "C" { // // llama_sampler_free(smpl); // - // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). - // typedef void * llama_sampler_context_t; + struct llama_sampler_backend_data { + struct ggml_tensor * logits; + struct ggml_tensor * probs; + struct ggml_tensor * sampled; + struct ggml_tensor * candidates; + }; + // user code can implement the interface below in order to create custom llama_sampler struct llama_sampler_i { const char * (*name) (const struct llama_sampler * smpl); // can be NULL @@ -1195,25 +1193,28 @@ extern "C" { struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL - void (*apply_ggml)( struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data); + // backend sampling interface + void (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft); - void (*accept_ggml)( struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct ggml_tensor * selected_token); + void (*backend_accept)( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct ggml_tensor * selected_token); - void (*set_input_ggml)(struct llama_sampler * smpl); + void (*backend_apply)( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_backend_data * ggml_data); - void (*init_ggml)(struct llama_sampler * smpl, - ggml_backend_buffer_type_t buft); + void (*backend_set_input)(struct llama_sampler * smpl); }; struct llama_sampler { const struct llama_sampler_i * iface; - llama_sampler_context_t ctx; + + llama_sampler_context_t ctx; }; LLAMA_API void llama_set_backend_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl); @@ -1228,17 +1229,6 @@ extern "C" { // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add) LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); - LLAMA_API void llama_sampler_init_ggml (struct llama_sampler * smpl, ggml_backend_buffer_type_t buft); - LLAMA_API void llama_sampler_set_input_ggml(struct llama_sampler * smpl); - LLAMA_API void llama_sampler_apply_ggml (struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data); - LLAMA_API void llama_sampler_accept_ggml (struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct ggml_tensor * selected_token); - // llama_sampler_chain // a type of llama_sampler that can chain multiple samplers one after another diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 73a37773f7..3d88dcd296 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -465,8 +465,8 @@ void 
llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); for (const auto & [seq_id, sampler] : samplers) { - if (sampler->iface->set_input_ggml) { - sampler->iface->set_input_ggml(sampler); + if (sampler->iface->backend_set_input) { + sampler->iface->backend_set_input(sampler); } } } @@ -2088,8 +2088,9 @@ void llm_graph_context::build_sampling() const { const int32_t row_idx = it->second; - // Allow GPU sampler to create input tensors by implementing init_ggml. - if (sampler->iface->init_ggml != nullptr) { - sampler->iface->init_ggml(sampler, buft); + // Allow GPU sampler to create input tensors by implementing backend_init. + // TODO: this should not be done here + if (sampler->iface->backend_init != nullptr) { + sampler->iface->backend_init(sampler, buft); } active_samplers[seq_id] = sampler; @@ -2097,33 +2098,34 @@ ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, n_vocab, row_idx * logits_t->nb[1]); ggml_format_name(logits_seq, "logits_seq_%d", seq_id); - struct llama_sampler_ggml_data ggml_data = { + struct llama_sampler_backend_data data = { /*.logits =*/ logits_seq, /*.probs =*/ nullptr, /*.sampled =*/ nullptr, /*.candidates =*/ nullptr, }; - llama_sampler_apply_ggml(sampler, ctx0, gf, &ggml_data); + GGML_ASSERT(sampler->iface->backend_apply); + sampler->iface->backend_apply(sampler, ctx0, gf, &data); - if (ggml_data.sampled != nullptr) { - res->t_sampled[seq_id] = ggml_data.sampled; - ggml_build_forward_expand(gf, ggml_data.sampled); + if (data.sampled != nullptr) { + res->t_sampled[seq_id] = data.sampled; + ggml_build_forward_expand(gf, data.sampled); } - if (ggml_data.probs != nullptr) { - res->t_sampled_probs[seq_id] = ggml_data.probs; - ggml_build_forward_expand(gf, ggml_data.probs); + if (data.probs != nullptr) { + res->t_sampled_probs[seq_id] = data.probs; + ggml_build_forward_expand(gf, data.probs); } - if (ggml_data.logits != logits_seq) { - res->t_sampled_logits[seq_id] = ggml_data.logits; + if (data.logits != logits_seq) { + res->t_sampled_logits[seq_id] = data.logits; ggml_build_forward_expand(gf, res->t_sampled_logits[seq_id]); } - if (ggml_data.candidates != nullptr) { - res->t_candidates[seq_id] = ggml_data.candidates; - ggml_build_forward_expand(gf, ggml_data.candidates); + if (data.candidates != nullptr) { + res->t_candidates[seq_id] = data.candidates; + ggml_build_forward_expand(gf, data.candidates); } } diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index a13be03240..8069aa6802 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -348,7 +348,9 @@ static uint32_t get_rng_seed(uint32_t seed) { // llama_sampler API -struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) { +struct llama_sampler * llama_sampler_init( + const struct llama_sampler_i * iface, + llama_sampler_context_t ctx) { return new llama_sampler { /* .iface = */ iface, /* .ctx = */ ctx, @@ -374,39 +376,6 @@ void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_ar smpl->iface->apply(smpl, cur_p); } -void llama_sampler_apply_ggml( - struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data) { - GGML_ASSERT(smpl->iface->apply_ggml); - smpl->iface->apply_ggml(smpl, ctx, gf, ggml_data); -} - -void llama_sampler_accept_ggml( - struct llama_sampler * smpl, - ggml_context * ctx, - ggml_cgraph * gf, - struct ggml_tensor * selected_token) { - if
(smpl->iface->accept_ggml) { - smpl->iface->accept_ggml(smpl, ctx, gf, selected_token); - } -} - -void llama_sampler_set_input_ggml(struct llama_sampler * smpl) { - if (smpl->iface->set_input_ggml) { - smpl->iface->set_input_ggml(smpl); - } -} - -void llama_sampler_init_ggml( - struct llama_sampler * smpl, - ggml_backend_buffer_type_t buft) { - if (smpl->iface->init_ggml) { - smpl->iface->init_ggml(smpl, buft); - } -} - void llama_sampler_reset(struct llama_sampler * smpl) { if (smpl->iface->reset) { smpl->iface->reset(smpl); @@ -523,10 +492,10 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d time_meas tm(chain->t_sample_us, chain->params.no_perf); for (auto * smpl : chain->samplers) { - // Skip GPU samplers - they have apply_ggml but no apply if (smpl->iface->apply == nullptr) { continue; } + llama_sampler_apply(smpl, cur_p); } } @@ -561,21 +530,19 @@ static void llama_sampler_chain_free(struct llama_sampler * smpl) { delete chain; } -static void llama_sampler_chain_apply_ggml( - struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data) { +static void llama_sampler_chain_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { auto * chain = (llama_sampler_chain *) smpl->ctx; for (auto * smpl : chain->samplers) { - if (smpl->iface->apply_ggml) { - smpl->iface->apply_ggml(smpl, ctx, gf, ggml_data); + if (smpl->iface->backend_init) { + smpl->iface->backend_init(smpl, buft); } } } -static void llama_sampler_chain_accept_ggml( +static void llama_sampler_chain_backend_accept( struct llama_sampler * smpl, ggml_context * ctx, ggml_cgraph * gf, @@ -583,45 +550,47 @@ static void llama_sampler_chain_accept_ggml( auto * chain = (llama_sampler_chain *) smpl->ctx; for (auto * smpl : chain->samplers) { - if (smpl->iface->accept_ggml) { - smpl->iface->accept_ggml(smpl, ctx, gf, selected_token); + if (smpl->iface->backend_accept) { + smpl->iface->backend_accept(smpl, ctx, gf, selected_token); } } } -static void llama_sampler_chain_set_input_ggml(struct llama_sampler * smpl) { +static void llama_sampler_chain_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_backend_data * data) { auto * chain = (llama_sampler_chain *) smpl->ctx; for (auto * smpl : chain->samplers) { - if (smpl->iface->set_input_ggml) { - smpl->iface->set_input_ggml(smpl); + if (smpl->iface->backend_apply) { + smpl->iface->backend_apply(smpl, ctx, gf, data); } } } -static void llama_sampler_chain_set_backend_context( - struct llama_sampler * smpl, - ggml_backend_buffer_type_t buft) { +static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) { auto * chain = (llama_sampler_chain *) smpl->ctx; for (auto * smpl : chain->samplers) { - if (smpl->iface->init_ggml) { - smpl->iface->init_ggml(smpl,buft); + if (smpl->iface->backend_set_input) { + smpl->iface->backend_set_input(smpl); } } } static struct llama_sampler_i llama_sampler_chain_i = { - /* .name = */ llama_sampler_chain_name, - /* .accept = */ llama_sampler_chain_accept, - /* .apply = */ llama_sampler_chain_apply, - /* .reset = */ llama_sampler_chain_reset, - /* .clone = */ llama_sampler_chain_clone, - /* .free = */ llama_sampler_chain_free, - /* .apply_ggml = */ llama_sampler_chain_apply_ggml, - /* .accept_ggml = */ llama_sampler_chain_accept_ggml, - /* .set_input_ggml = */ llama_sampler_chain_set_input_ggml, - /* .init_ggml = */
llama_sampler_chain_set_backend_context, + /* .name = */ llama_sampler_chain_name, + /* .accept = */ llama_sampler_chain_accept, + /* .apply = */ llama_sampler_chain_apply, + /* .reset = */ llama_sampler_chain_reset, + /* .clone = */ llama_sampler_chain_clone, + /* .free = */ llama_sampler_chain_free, + /* .backend_init = */ llama_sampler_chain_backend_init, + /* .backend_accept = */ llama_sampler_chain_backend_accept, + /* .backend_apply = */ llama_sampler_chain_backend_apply, + /* .backend_set_input = */ llama_sampler_chain_backend_set_input, }; struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { @@ -689,29 +658,29 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to } } -static void llama_sampler_greedy_apply_ggml( +static void llama_sampler_greedy_backend_apply( struct llama_sampler * smpl, struct ggml_context * ctx, struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data) { + struct llama_sampler_backend_data * data) { GGML_UNUSED(gf); GGML_UNUSED(smpl); - struct ggml_tensor * argmax_result = ggml_argmax(ctx, ggml_data->logits); + struct ggml_tensor * argmax_result = ggml_argmax(ctx, data->logits); ggml_set_name(argmax_result, "argmax_result"); - ggml_data->sampled = argmax_result; + data->sampled = argmax_result; } static struct llama_sampler_i llama_sampler_greedy_i = { - /* .name = */ llama_sampler_greedy_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_greedy_apply, - /* .reset = */ nullptr, - /* .clone = */ nullptr, - /* .free = */ nullptr, - /* .apply_ggml = */ llama_sampler_greedy_apply_ggml, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .init_ggml = */ nullptr, + /* .name = */ llama_sampler_greedy_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_greedy_apply, + /* .reset = */ nullptr, + /* .clone = */ nullptr, + /* .free = */ nullptr, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_greedy_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_greedy() { @@ -838,15 +807,24 @@ static void llama_sampler_dist_free(struct llama_sampler * smpl) { delete (llama_sampler_dist *) smpl->ctx; } -static void llama_sampler_dist_apply_ggml( - struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data) { +static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_dist *) smpl->ctx; + GGML_ASSERT(sctx->inp_uniform != nullptr); + + std::uniform_real_distribution<float> dist(0.0f, 1.0f); + const float rnd = dist(sctx->rng); + ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float)); +} + +static void llama_sampler_dist_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_backend_data * data) { GGML_UNUSED(gf); auto * sctx = (llama_sampler_dist *) smpl->ctx; - struct ggml_tensor * probs = ggml_soft_max(ctx, ggml_data->logits); + struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits); ggml_set_name(probs, "dist_probs"); struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs); @@ -883,8 +861,8 @@ static void llama_sampler_dist_apply_ggml( // Map back to original vocab ids if a candidates tensor is available.
struct ggml_tensor * sampled_token = idx; - if (ggml_data->candidates != nullptr) { - struct ggml_tensor * candidates = ggml_data->candidates; + if (data->candidates != nullptr) { + struct ggml_tensor * candidates = data->candidates; struct ggml_tensor * candidates_reshaped = ggml_view_2d(ctx, candidates, 1, ggml_nelements(candidates), ggml_type_size(candidates->type), 0); @@ -893,19 +871,10 @@ static void llama_sampler_dist_apply_ggml( } ggml_set_output(sampled_token); - ggml_data->sampled = sampled_token; + data->sampled = sampled_token; } -static void llama_sampler_dist_set_input_ggml(struct llama_sampler * smpl) { - auto * sctx = (llama_sampler_dist *) smpl->ctx; - GGML_ASSERT(sctx->inp_uniform != nullptr); - - std::uniform_real_distribution<float> dist(0.0f, 1.0f); - const float rnd = dist(sctx->rng); - ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float)); -} - -static void llama_sampler_dist_init_ggml( +static void llama_sampler_dist_backend_init( struct llama_sampler * smpl, ggml_backend_buffer_type_t buft) { auto * sctx = (llama_sampler_dist *) smpl->ctx; @@ -921,7 +890,7 @@ static void llama_sampler_dist_init_ggml( sctx->inp_ctx.reset(ggml_init(params)); // Create the uniform random scalar input tensor. This will be set by - // llama_sampler_dist_set_input_ggml after this graph is built. + // llama_sampler_dist_backend_set_input after this graph is built. sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1); ggml_set_name(sctx->inp_uniform, "uniform"); ggml_set_input(sctx->inp_uniform); @@ -931,16 +900,16 @@ } static struct llama_sampler_i llama_sampler_dist_i = { - /* .name = */ llama_sampler_dist_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_dist_apply, - /* .reset = */ llama_sampler_dist_reset, - /* .clone = */ llama_sampler_dist_clone, - /* .free = */ llama_sampler_dist_free, - /* .apply_ggml = */ llama_sampler_dist_apply_ggml, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ llama_sampler_dist_set_input_ggml, - /* .init_ggml = */ llama_sampler_dist_init_ggml, + /* .name = */ llama_sampler_dist_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_dist_apply, + /* .reset = */ llama_sampler_dist_reset, + /* .clone = */ llama_sampler_dist_clone, + /* .free = */ llama_sampler_dist_free, + /* .backend_init = */ llama_sampler_dist_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_dist_backend_apply, + /* .backend_set_input = */ llama_sampler_dist_backend_set_input, }; struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { @@ -986,15 +955,22 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) { delete (llama_sampler_top_k *) smpl->ctx; } -static void llama_sampler_top_k_apply_ggml( +static void llama_sampler_top_k_backend_init( struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data) { + ggml_backend_buffer_type_t buft) { + auto * ctx_data = (llama_sampler_top_k *) smpl->ctx; + ctx_data->device = ggml_backend_buft_get_device(buft); +} + +static void llama_sampler_top_k_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_backend_data * data) { auto * ctx_data = (llama_sampler_top_k *) smpl->ctx; - struct ggml_tensor * top_k = ggml_top_k(ctx, ggml_data->logits, ctx_data->k); + struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, ctx_data->k); ggml_set_name(top_k,
"top_k"); // top_k is a view of argsort - check if backend supports the underlying argsort operation @@ -1004,34 +980,27 @@ static void llama_sampler_top_k_apply_ggml( fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n"); } - ggml_data->candidates = top_k; + data->candidates = top_k; - struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]); + struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k); ggml_set_name(top_k_rows, "top_k_rows"); - ggml_data->logits = ggml_reshape_1d(ctx, top_k_rows, ctx_data->k); - ggml_build_forward_expand(gf, ggml_data->logits); -} - -static void llama_sampler_top_k_init_ggml( - struct llama_sampler * smpl, - ggml_backend_buffer_type_t buft) { - auto * ctx_data = (llama_sampler_top_k *) smpl->ctx; - ctx_data->device = ggml_backend_buft_get_device(buft); + data->logits = ggml_reshape_1d(ctx, top_k_rows, ctx_data->k); + ggml_build_forward_expand(gf, data->logits); } static struct llama_sampler_i llama_sampler_top_k_i = { - /* .name = */ llama_sampler_top_k_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_k_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_k_clone, - /* .free = */ llama_sampler_top_k_free, - /* .apply_ggml = */ llama_sampler_top_k_apply_ggml, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .init_ggml = */ llama_sampler_top_k_init_ggml, + /* .name = */ llama_sampler_top_k_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_k_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_k_clone, + /* .free = */ llama_sampler_top_k_free, + /* .backend_init = */ llama_sampler_top_k_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_top_k_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_k(int32_t k) { @@ -1124,14 +1093,21 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) { delete (llama_sampler_top_p *) smpl->ctx; } -static void llama_sampler_top_p_apply_ggml( - struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data) { +static void llama_sampler_top_p_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_top_p *) smpl->ctx; + sctx->device = ggml_backend_buft_get_device(buft); +} + +static void llama_sampler_top_p_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_backend_data * data) { auto * sctx = (llama_sampler_top_p *) smpl->ctx; - struct ggml_tensor * softmax = ggml_soft_max(ctx, ggml_data->logits); + struct ggml_tensor * softmax = ggml_soft_max(ctx, data->logits); ggml_set_name(softmax, "top_p_softmax"); // Get the sorted indices of the softmax probabilities in descending order. 
@@ -1181,30 +1157,23 @@ static void llama_sampler_top_p_apply_ggml( struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val); ggml_set_name(top_p_bias, "top_p_bias"); - ggml_data->logits = ggml_add(ctx, ggml_data->logits, top_p_bias); - ggml_set_name(ggml_data->logits, "top_p_logits"); + data->logits = ggml_add(ctx, data->logits, top_p_bias); + ggml_set_name(data->logits, "top_p_logits"); - ggml_build_forward_expand(gf, ggml_data->logits); -} - -static void llama_sampler_top_p_init_ggml( - struct llama_sampler * smpl, - ggml_backend_buffer_type_t buft) { - auto * sctx = (llama_sampler_top_p *) smpl->ctx; - sctx->device = ggml_backend_buft_get_device(buft); + ggml_build_forward_expand(gf, data->logits); } static struct llama_sampler_i llama_sampler_top_p_i = { - /* .name = */ llama_sampler_top_p_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_p_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_p_clone, - /* .free = */ llama_sampler_top_p_free, - /* .apply_ggml = */ llama_sampler_top_p_apply_ggml, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .init_ggml = */ llama_sampler_top_p_init_ggml, + /* .name = */ llama_sampler_top_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_p_clone, + /* .free = */ llama_sampler_top_p_free, + /* .backend_init = */ llama_sampler_top_p_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_top_p_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { @@ -1296,17 +1265,24 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) { delete (llama_sampler_min_p *) smpl->ctx; } -static void llama_sampler_min_p_apply_ggml( - struct llama_sampler * smpl, - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data) { +static void llama_sampler_min_p_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_min_p *) smpl->ctx; + sctx->device = ggml_backend_buft_get_device(buft); +} + +static void llama_sampler_min_p_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_backend_data * data) { auto * sctx = (llama_sampler_min_p *) smpl->ctx; - struct ggml_tensor * max_idx = ggml_argmax(ctx, ggml_data->logits); + struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits); ggml_set_name(max_idx, "max_idx"); - struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]); + struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); ggml_set_name(logits_rows, "logits_rows"); struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx); @@ -1317,7 +1293,7 @@ static void llama_sampler_min_p_apply_ggml( ggml_set_name(threshold, "min_p_threshold"); // Subtract the threshold from logits. - struct ggml_tensor * sub = ggml_sub(ctx, ggml_data->logits, threshold); + struct ggml_tensor * sub = ggml_sub(ctx, data->logits, threshold); // Create a mask where logits below the threshold are 0 (discard), // and others are 1 (keep). @@ -1333,30 +1309,23 @@ static void llama_sampler_min_p_apply_ggml( ggml_set_name(min_p_bias, "min_p_bias"); // Add the min_p bias to the logits. 
- ggml_data->logits = ggml_add(ctx, ggml_data->logits, min_p_bias); - ggml_set_name(ggml_data->logits, "min_p_logits"); + data->logits = ggml_add(ctx, data->logits, min_p_bias); + ggml_set_name(data->logits, "min_p_logits"); - ggml_build_forward_expand(gf, ggml_data->logits); -} - -static void llama_sampler_min_p_init_ggml( - struct llama_sampler * smpl, - ggml_backend_buffer_type_t buft) { - auto * sctx = (llama_sampler_min_p *) smpl->ctx; - sctx->device = ggml_backend_buft_get_device(buft); + ggml_build_forward_expand(gf, data->logits); } static struct llama_sampler_i llama_sampler_min_p_i = { - /* .name = */ llama_sampler_min_p_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_min_p_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_min_p_clone, - /* .free = */ llama_sampler_min_p_free, - /* .apply_ggml = */ llama_sampler_min_p_apply_ggml, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .init_ggml = */ llama_sampler_min_p_init_ggml, + /* .name = */ llama_sampler_min_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_min_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_min_p_clone, + /* .free = */ llama_sampler_min_p_free, + /* .backend_init = */ llama_sampler_min_p_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_min_p_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { @@ -1451,16 +1420,16 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_typical_i = { - /* .name = */ llama_sampler_typical_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_typical_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_typical_clone, - /* .free = */ llama_sampler_typical_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .init_ggml = */ nullptr, + /* .name = */ llama_sampler_typical_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_typical_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_typical_clone, + /* .free = */ llama_sampler_typical_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { @@ -1498,38 +1467,38 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) { delete (llama_sampler_temp *) smpl->ctx; } -static void llama_sampler_temp_apply_ggml( +static void llama_sampler_temp_backend_apply( struct llama_sampler * smpl, struct ggml_context * ctx, struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data) { + struct llama_sampler_backend_data * data) { auto * ctx_data = (llama_sampler_temp *) smpl->ctx; if (ctx_data->temp <= 0.0f) { return; } - struct ggml_tensor * scaled = ggml_scale(ctx, ggml_data->logits, 1.0f / ctx_data->temp); + struct ggml_tensor * scaled = ggml_scale(ctx, data->logits, 1.0f / ctx_data->temp); ggml_set_name(scaled, "temp_scaled"); // Make sure the scaled tensor is contiguous for subsequent operations - ggml_data->logits = ggml_cont(ctx, scaled); - ggml_set_name(ggml_data->logits, "temp_scaled_logits"); + data->logits = ggml_cont(ctx, scaled); + ggml_set_name(data->logits, "temp_scaled_logits"); - ggml_build_forward_expand(gf, ggml_data->logits); + ggml_build_forward_expand(gf, data->logits); } static 
struct llama_sampler_i llama_sampler_temp_i = { - /* .name = */ llama_sampler_temp_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_temp_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_temp_clone, - /* .free = */ llama_sampler_temp_free, - /* .apply_ggml = */ llama_sampler_temp_apply_ggml, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .init_ggml = */ nullptr, + /* .name = */ llama_sampler_temp_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_clone, + /* .free = */ llama_sampler_temp_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_temp_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_temp(float temp) { @@ -1634,16 +1603,16 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_temp_ext_i = { - /* .name = */ llama_sampler_temp_ext_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_temp_ext_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_temp_ext_clone, - /* .free = */ llama_sampler_temp_ext_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .init_ggml = */ nullptr, + /* .name = */ llama_sampler_temp_ext_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_ext_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_ext_clone, + /* .free = */ llama_sampler_temp_ext_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { @@ -1732,16 +1701,16 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_xtc_i = { - /* .name = */ llama_sampler_xtc_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sample_xtc_apply, - /* .reset = */ llama_sampler_xtc_reset, - /* .clone = */ llama_sampler_xtc_clone, - /* .free = */ llama_sampler_xtc_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .init_ggml = */ nullptr, + /* .name = */ llama_sampler_xtc_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sample_xtc_apply, + /* .reset = */ llama_sampler_xtc_reset, + /* .clone = */ llama_sampler_xtc_clone, + /* .free = */ llama_sampler_xtc_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { @@ -1844,16 +1813,16 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_mirostat_i = { - /* .name = */ llama_sampler_mirostat_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_mirostat_apply, - /* .reset = */ llama_sampler_mirostat_reset, - /* .clone = */ llama_sampler_mirostat_clone, - /* .free = */ llama_sampler_mirostat_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .init_ggml = */ nullptr, + /* .name = */ llama_sampler_mirostat_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_apply, + /* .reset = */ llama_sampler_mirostat_reset, + 
/* .clone = */ llama_sampler_mirostat_clone, + /* .free = */ llama_sampler_mirostat_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { @@ -1947,16 +1916,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_mirostat_v2_i = { - /* .name = */ llama_sampler_mirostat_v2_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_mirostat_v2_apply, - /* .reset = */ llama_sampler_mirostat_v2_reset, - /* .clone = */ llama_sampler_mirostat_v2_clone, - /* .free = */ llama_sampler_mirostat_v2_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .init_ggml = */ nullptr, + /* .name = */ llama_sampler_mirostat_v2_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_v2_apply, + /* .reset = */ llama_sampler_mirostat_v2_reset, + /* .clone = */ llama_sampler_mirostat_v2_clone, + /* .free = */ llama_sampler_mirostat_v2_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { @@ -2068,16 +2037,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_grammar_i = { - /* .name = */ llama_sampler_grammar_name, - /* .accept = */ llama_sampler_grammar_accept_impl, - /* .apply = */ llama_sampler_grammar_apply, - /* .reset = */ llama_sampler_grammar_reset, - /* .clone = */ llama_sampler_grammar_clone, - /* .free = */ llama_sampler_grammar_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .init_ggml = */ nullptr, + /* .name = */ llama_sampler_grammar_name, + /* .accept = */ llama_sampler_grammar_accept_impl, + /* .apply = */ llama_sampler_grammar_apply, + /* .reset = */ llama_sampler_grammar_reset, + /* .clone = */ llama_sampler_grammar_clone, + /* .free = */ llama_sampler_grammar_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; static struct llama_sampler * llama_sampler_init_grammar_impl( @@ -2279,16 +2248,16 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_penalties_i = { - /* .name = */ llama_sampler_penalties_name, - /* .accept = */ llama_sampler_penalties_accept, - /* .apply = */ llama_sampler_penalties_apply, - /* .reset = */ llama_sampler_penalties_reset, - /* .clone = */ llama_sampler_penalties_clone, - /* .free = */ llama_sampler_penalties_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .init_ggml = */ nullptr, + /* .name = */ llama_sampler_penalties_name, + /* .accept = */ llama_sampler_penalties_accept, + /* .apply = */ llama_sampler_penalties_apply, + /* .reset = */ llama_sampler_penalties_reset, + /* .clone = */ llama_sampler_penalties_clone, + /* .free = */ llama_sampler_penalties_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_penalties( @@ -2374,16 +2343,16 @@ static 
void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_top_n_sigma_i = { - /* .name = */ llama_sampler_top_n_sigma_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_n_sigma_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_n_sigma_clone, - /* .free = */ llama_sampler_top_n_sigma_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .init_ggml = */ nullptr, + /* .name = */ llama_sampler_top_n_sigma_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_n_sigma_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_n_sigma_clone, + /* .free = */ llama_sampler_top_n_sigma_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_n_sigma(float n) { @@ -2708,16 +2677,16 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_dry_i = { - /* .name = */ llama_sampler_dry_name, - /* .accept = */ llama_sampler_dry_accept, - /* .apply = */ llama_sampler_dry_apply, - /* .reset = */ llama_sampler_dry_reset, - /* .clone = */ llama_sampler_dry_clone, - /* .free = */ llama_sampler_dry_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .init_ggml = */ nullptr, + /* .name = */ llama_sampler_dry_name, + /* .accept = */ llama_sampler_dry_accept, + /* .apply = */ llama_sampler_dry_apply, + /* .reset = */ llama_sampler_dry_reset, + /* .clone = */ llama_sampler_dry_clone, + /* .free = */ llama_sampler_dry_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { @@ -2857,11 +2826,11 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) { delete (llama_sampler_logit_bias *) smpl->ctx; } -static void llama_sampler_logit_bias_apply_ggml( +static void llama_sampler_logit_bias_backend_apply( struct llama_sampler * smpl, struct ggml_context * ctx, struct ggml_cgraph * gf, - struct llama_sampler_ggml_data * ggml_data) { + struct llama_sampler_backend_data * data) { GGML_UNUSED(gf); GGML_UNUSED(ctx); @@ -2871,11 +2840,11 @@ static void llama_sampler_logit_bias_apply_ggml( } // Add the sparse logit logit_bias to the logits - struct ggml_tensor * logit_biased = ggml_add_inplace(ctx, ggml_data->logits, sctx->inp_logit_bias); + struct ggml_tensor * logit_biased = ggml_add_inplace(ctx, data->logits, sctx->inp_logit_bias); ggml_build_forward_expand(gf, logit_biased); } -static void llama_sampler_logit_bias_set_input_ggml(struct llama_sampler * smpl) { +static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) { auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; if (sctx->logit_bias.empty()) { return; @@ -2892,7 +2861,7 @@ static void llama_sampler_logit_bias_set_input_ggml(struct llama_sampler * smpl) ggml_backend_tensor_set(sctx->inp_logit_bias, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->inp_logit_bias)); } -static void llama_sampler_logit_bias_init_ggml( +static void llama_sampler_logit_bias_backend_init( struct 
llama_sampler * smpl, ggml_backend_buffer_type_t buft) { auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; @@ -2918,16 +2887,16 @@ static void llama_sampler_logit_bias_init_ggml( } static struct llama_sampler_i llama_sampler_logit_bias_i = { - /* .name = */ llama_sampler_logit_bias_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_logit_bias_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_logit_bias_clone, - /* .free = */ llama_sampler_logit_bias_free, - /* .apply_ggml = */ llama_sampler_logit_bias_apply_ggml, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ llama_sampler_logit_bias_set_input_ggml, - /* .init_ggml = */ llama_sampler_logit_bias_init_ggml, + /* .name = */ llama_sampler_logit_bias_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_logit_bias_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_logit_bias_clone, + /* .free = */ llama_sampler_logit_bias_free, + /* .backend_init = */ llama_sampler_logit_bias_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_logit_bias_backend_apply, + /* .backend_set_input = */ llama_sampler_logit_bias_backend_set_input, }; struct llama_sampler * llama_sampler_init_logit_bias( @@ -3155,16 +3124,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_infill_i = { - /* .name = */ llama_sampler_infill_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_infill_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_infill_clone, - /* .free = */ llama_sampler_infill_free, - /* .apply_ggml = */ nullptr, - /* .accept_ggml = */ nullptr, - /* .set_input_ggml = */ nullptr, - /* .init_ggml = */ nullptr, + /* .name = */ llama_sampler_infill_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_infill_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_infill_clone, + /* .free = */ llama_sampler_infill_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) { diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 759dd7dcb7..80ea22ac35 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -24,9 +24,9 @@ struct llama_sampler_chain { }; struct llama_sampler * llama_sampler_init_dry_testing( - int32_t context_size, - float dry_multiplier, - float dry_base, - int32_t dry_allowed_length, - int32_t dry_penalty_last_n, - const std::vector<std::vector<llama_token>>& seq_breakers); + int32_t context_size, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const std::vector<std::vector<llama_token>> & seq_breakers); diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp index d6839c8805..918766994b 100644 --- a/tests/test-backend-sampler.cpp +++ b/tests/test-backend-sampler.cpp @@ -345,7 +345,7 @@ static void test_backend_top_k_sampling(const char * model_path) { // sampling, first top_k on the backend and then dist on the CPU. struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); struct llama_sampler * chain = llama_sampler_chain_init(chain_params); - GGML_ASSERT(chain->iface->apply_ggml != nullptr); + GGML_ASSERT(chain->iface->backend_apply != nullptr); llama_sampler_chain_add(chain, llama_sampler_init_dist(18)); llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
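[Reviewer note] For out-of-tree samplers the migration is mechanical: the callbacks keep their signatures and are only renamed and reordered (apply_ggml -> backend_apply, accept_ggml -> backend_accept, set_input_ggml -> backend_set_input, init_ggml -> backend_init). Below is a minimal sketch of a custom backend sampler written against the renamed interface; the sampler itself (a fixed logit scale) and every my_* name are hypothetical, not part of this patch:

    #include "llama.h"
    #include "ggml.h"

    struct my_scale_ctx {
        float factor; // fixed scale applied to the logits on the backend
    };

    static const char * my_scale_name(const struct llama_sampler * /*smpl*/) {
        return "my-scale";
    }

    static void my_scale_free(struct llama_sampler * smpl) {
        delete (my_scale_ctx *) smpl->ctx;
    }

    static void my_scale_backend_apply(
            struct llama_sampler               * smpl,
            struct ggml_context                * ctx,
            struct ggml_cgraph                 * gf,
            struct llama_sampler_backend_data * data) {
        auto * sctx = (my_scale_ctx *) smpl->ctx;

        // rewrite data->logits in place; downstream samplers see the scaled tensor
        data->logits = ggml_scale(ctx, data->logits, sctx->factor);
        ggml_set_name(data->logits, "my_scale_logits");

        ggml_build_forward_expand(gf, data->logits);
    }

    static struct llama_sampler_i my_scale_i = {
        /* .name              = */ my_scale_name,
        /* .accept            = */ nullptr,
        /* .apply             = */ nullptr, // no CPU fallback: chains skip it, see llama_sampler_chain_apply
        /* .reset             = */ nullptr,
        /* .clone             = */ nullptr,
        /* .free              = */ my_scale_free,
        /* .backend_init      = */ nullptr, // no input tensors to create
        /* .backend_accept    = */ nullptr,
        /* .backend_apply     = */ my_scale_backend_apply,
        /* .backend_set_input = */ nullptr, // nothing to upload per decode
    };

    struct llama_sampler * my_scale_init(float factor) {
        return llama_sampler_init(&my_scale_i, new my_scale_ctx { factor });
    }

Attaching it to a sequence would then look like llama_set_backend_sampler(ctx, seq_id, my_scale_init(0.5f)), per the declaration in llama.h above.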