sampling : add support for backend sampling
This commit adds support for performing sampling operations on the backend (e.g. the GPU) as part of the model's computation graph, allowing some or all of the sampling to be done on the device. For example, the backend sampler chain might select/sample a token directly, in which case only the sampled token needs to be transferred from device memory to host memory. It is also possible for the backend samplers to perform filtering of the logits, or to compute and filter the probability distribution, in which case only the filtered logits or probabilities need to be transferred back to system memory for further processing by the CPU samplers. Currently, backend sampling works in a similar manner to pooling: it is a function called by build_graph, and the sampler operations become part of the model's computation graph.
parent: cb623de3fc
commit: 7884b0e0ac
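To illustrate how the pieces added by this commit fit together, here is a hedged usage sketch (not code from the patch; it assumes a llama_model * model and llama_context * ctx have already been created and a batch has been prepared for llama_decode):

    // Hypothetical usage sketch of the backend sampling API introduced in this commit.
    llama_sampler_chain_params cp = llama_sampler_chain_default_params();
    struct llama_sampler * chain = llama_sampler_chain_init(cp);
    llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(0.8f));
    llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(LLAMA_DEFAULT_SEED));

    // Attach the chain to sequence 0 so the sampler ops become part of the decode graph.
    llama_set_backend_sampler(ctx, 0, chain);

    // ... llama_decode(ctx, batch) ...

    // If the chain ends with a dist sampler, only the sampled token id is copied back to host memory.
    llama_token tok = llama_get_backend_sampled_token_ith(ctx, -1);
    if (tok == LLAMA_TOKEN_NULL) {
        // The chain only filtered: fetch the filtered logits and their vocab ids and
        // finish sampling on the CPU (e.g. via common_sampler_sample).
        const float       * logits = llama_get_backend_sampled_logits_ith(ctx, -1);
        const llama_token * ids    = llama_get_backend_sampled_token_ids_ith(ctx, -1);
        const uint32_t      n      = llama_get_backend_sampled_logits_count_ith(ctx, -1);
        // ...
    }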
@ -1501,6 +1501,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
            params.sampling.grammar = json_schema_to_grammar(json::parse(schema));
        }
    ).set_sparam());
    add_opt(common_arg(
        {"--backend-sampling"},
        "enable backend sampling (default: disabled)",
        [](common_params & params) {
            params.sampling.backend_sampling = true;
        }
    ).set_sparam());
    add_opt(common_arg(
        {"--backend-dist"},
        "perform final (distribution) sampling on backend (default: disabled)",
        [](common_params & params) {
            params.sampling.backend_dist = true;
            params.sampling.backend_sampling = true;
        }
    ).set_sparam());
    add_opt(common_arg(
        {"--pooling"}, "{none,mean,cls,last,rank}",
        "pooling type for embeddings, use model default if unspecified",
@ -8,6 +8,7 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "sampling.h"

#include <algorithm>
#include <cinttypes>
@ -956,6 +957,8 @@ struct common_init_result common_init_from_params(common_params & params) {
    const llama_vocab * vocab = llama_model_get_vocab(model);

    auto cparams = common_context_params_to_llama(params);
    cparams.samplers   = params.backend_samplers;
    cparams.n_samplers = params.n_backend_samplers;

    llama_context * lctx = llama_init_from_model(model, cparams);
    if (lctx == NULL) {
@ -188,6 +188,10 @@ struct common_params_sampling {
    std::vector<llama_logit_bias> logit_bias;     // logit biases to apply
    std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens

    // Backend sampling flags
    bool backend_sampling = false; // enable backend sampling
    bool backend_dist     = false; // backend performs final sampling (dist)

    // print the parameters into a string
    std::string print() const;
};
@ -512,6 +516,9 @@ struct common_params {
    bool has_speculative() const {
        return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
    }

    struct llama_sampler_seq_config * backend_samplers = NULL;
    size_t n_backend_samplers = 0;
};

// call once at the start of a program if it uses libcommon
@ -106,12 +106,16 @@ static void llama_sampler_llg_free(llama_sampler * smpl) {
}

static llama_sampler_i llama_sampler_llg_i = {
    /* .name   = */ llama_sampler_llg_name,
    /* .accept = */ llama_sampler_llg_accept_impl,
    /* .apply  = */ llama_sampler_llg_apply,
    /* .reset  = */ llama_sampler_llg_reset,
    /* .clone  = */ llama_sampler_llg_clone,
    /* .free   = */ llama_sampler_llg_free,
    /* .name   = */ llama_sampler_llg_name,
    /* .accept = */ llama_sampler_llg_accept_impl,
    /* .apply  = */ llama_sampler_llg_apply,
    /* .reset  = */ llama_sampler_llg_reset,
    /* .clone  = */ llama_sampler_llg_clone,
    /* .free   = */ llama_sampler_llg_free,
    /* .apply_ggml          = */ NULL,
    /* .accept_ggml         = */ NULL,
    /* .set_input_ggml      = */ NULL,
    /* .set_backend_context = */ NULL,
};

static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
@ -113,23 +113,61 @@ struct common_sampler {
|
|||
llama_token_data_array cur_p;
|
||||
|
||||
void set_logits(struct llama_context * ctx, int idx) {
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
const float * sampled_probs = llama_get_backend_sampled_probs_ith(ctx, idx);
|
||||
const float * sampled_logits = llama_get_backend_sampled_logits_ith(ctx, idx);
|
||||
const llama_token * sampled_ids = llama_get_backend_sampled_token_ids_ith(ctx, idx);
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
const int n_vocab = llama_vocab_n_tokens(vocab);
|
||||
|
||||
cur.resize(n_vocab);
|
||||
// Use the member variable instead of allocating locally
|
||||
cur.clear();
|
||||
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
|
||||
if (sampled_probs) {
|
||||
const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx);
|
||||
cur.reserve(sampled_probs_count);
|
||||
// The backend sampler has filtered the probabilities so we need to use the sampled ids.
|
||||
if (sampled_ids != nullptr) {
|
||||
for (uint32_t i = 0; i < sampled_probs_count; ++i) {
|
||||
cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
|
||||
}
|
||||
} else {
|
||||
for (llama_token token_id = 0; token_id < (int) sampled_probs_count; token_id++) {
|
||||
cur.emplace_back(llama_token_data{token_id, 0.0f, sampled_probs[token_id]});
|
||||
}
|
||||
}
|
||||
} else if (sampled_logits) {
|
||||
const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
|
||||
cur.reserve(sampled_logits_count);
|
||||
// The backend sampler has filtered the logits so we need to use the sampled ids.
|
||||
if (sampled_ids != nullptr) {
|
||||
for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
|
||||
cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
|
||||
}
|
||||
} else {
|
||||
for (llama_token token_id = 0; token_id < (int)sampled_logits_count; token_id++) {
|
||||
cur.emplace_back(llama_token_data{token_id, sampled_logits[token_id], 0.0f});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
GGML_ASSERT(logits != nullptr);
|
||||
cur.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
}
|
||||
|
||||
cur_p = { cur.data(), cur.size(), -1, false };
|
||||
}
|
||||
};
|
||||
|
||||
static bool sampler_enabled(const struct common_params_sampling & params, enum common_sampler_type type) {
|
||||
return std::find(params.samplers.begin(), params.samplers.end(), type) != params.samplers.end();
|
||||
}
|
||||
|
||||
std::string common_params_sampling::print() const {
|
||||
char result[1024];
|
||||
|
||||
|
|
@ -287,6 +325,43 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
|||
return result;
|
||||
}
|
||||
|
||||
struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params) {
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
|
||||
chain_params.no_perf = params.no_perf;
|
||||
|
||||
struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
|
||||
if (!params.backend_sampling) {
|
||||
return chain; // return empty chain
|
||||
}
|
||||
|
||||
const bool enable_temp = params.temp > 0.0f && sampler_enabled(params, COMMON_SAMPLER_TYPE_TEMPERATURE);
|
||||
const bool enable_top_k = params.top_k > 0 && sampler_enabled(params, COMMON_SAMPLER_TYPE_TOP_K);
|
||||
const bool enable_dist = params.backend_dist;
|
||||
|
||||
if (!params.logit_bias.empty()) {
|
||||
llama_sampler_chain_add(chain, llama_sampler_backend_init_logit_bias(
|
||||
llama_vocab_n_tokens(vocab),
|
||||
params.logit_bias.size(),
|
||||
params.logit_bias.data()));
|
||||
}
|
||||
|
||||
if (enable_temp) {
|
||||
llama_sampler_chain_add(chain, llama_sampler_backend_init_temp(params.temp));
|
||||
}
|
||||
|
||||
if (enable_top_k) {
|
||||
llama_sampler_chain_add(chain, llama_sampler_backend_init_top_k(params.top_k));
|
||||
}
|
||||
|
||||
if (enable_dist) {
|
||||
llama_sampler_chain_add(chain, llama_sampler_backend_init_dist(params.seed));
|
||||
}
|
||||
|
||||
return chain;
|
||||
}
|
||||
|
||||
void common_sampler_free(struct common_sampler * gsmpl) {
|
||||
if (gsmpl) {
|
||||
llama_sampler_free(gsmpl->grmr);
|
||||
|
|
@ -337,6 +412,14 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
|
|||
}
|
||||
|
||||
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
||||
// Check if a backend sampler has already sampled a token in which case we
|
||||
// return that token id directly.
|
||||
const llama_token backend_sampled_token = llama_get_backend_sampled_token_ith(ctx, idx);
|
||||
if (backend_sampled_token != LLAMA_TOKEN_NULL) {
|
||||
LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, backend_sampled_token);
|
||||
return backend_sampled_token;
|
||||
}
|
||||
|
||||
gsmpl->set_logits(ctx, idx);
|
||||
|
||||
auto & grmr = gsmpl->grmr;
|
||||
|
|
|
|||
|
|
@ -38,6 +38,13 @@ struct common_sampler;
|
|||
|
||||
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
|
||||
|
||||
// Create a backend sampler chain from common sampling parameters
|
||||
// Returns a llama_sampler chain configured with backend samplers based on the parameters
|
||||
// This chain can be used per-sequence for backend-based sampling
|
||||
// Note: Only samplers that have backend equivalents will be added to the chain
|
||||
// The returned sampler should be freed with llama_sampler_free()
|
||||
struct llama_sampler * common_sampler_backend_init(const struct llama_model * model, const struct common_params_sampling & params);
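A minimal sketch of how this helper might be attached per sequence (illustrative only; llama_set_backend_sampler comes from this diff, while n_parallel and sparams are assumed to exist in the caller):

    // Illustrative sketch, not part of the patch.
    // Assumes llama_model * model, llama_context * ctx, common_params_sampling sparams, int n_parallel.
    for (llama_seq_id s = 0; s < n_parallel; ++s) {
        struct llama_sampler * chain = common_sampler_backend_init(model, sparams);
        llama_set_backend_sampler(ctx, s, chain);
        // per this diff, samplers attached to the context are freed in llama_context's destructor
    }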
|
||||
|
||||
void common_sampler_free(struct common_sampler * gsmpl);
|
||||
|
||||
// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
|
||||
|
|
|
|||
|
|
@ -210,6 +210,13 @@ extern "C" {
        bool sorted; // note: do not assume the data is sorted - always check this flag
    } llama_token_data_array;

    struct llama_sampler_ggml_data {
        struct ggml_tensor * logits;
        struct ggml_tensor * probs;
        struct ggml_tensor * sampled_token;
        struct ggml_tensor * filtered_ids;
    };

    typedef bool (*llama_progress_callback)(float progress, void * user_data);

    // Input data for llama_encode/llama_decode
@ -300,6 +307,11 @@ extern "C" {
        bool no_host; // bypass host buffer allowing extra buffers to be used
    };

    struct llama_sampler_seq_config {
        llama_seq_id           seq_id;
        struct llama_sampler * sampler;
    };

    // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
    // https://github.com/ggml-org/llama.cpp/pull/7544
    struct llama_context_params {
@ -348,6 +360,10 @@ extern "C" {
        bool kv_unified; // use a unified buffer across the input sequences when computing the attention
                         // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
                         // ref: https://github.com/ggml-org/llama.cpp/pull/14363

        // backend sampler chain configuration
        struct llama_sampler_seq_config * samplers;
        size_t n_samplers;
    };

    // model quantization parameters
@ -950,6 +966,29 @@ extern "C" {
    // otherwise: float[n_embd] (1-dimensional)
    LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);

    // Get the backend sampled token for the ith token.
    // Returns LLAMA_TOKEN_NULL if no token was sampled.
    LLAMA_API llama_token llama_get_backend_sampled_token_ith(struct llama_context * ctx, int32_t i);

    // Get the backend sampled probabilities for the ith token.
    // The index matches llama_get_backend_sampled_token_ith().
    // Returns NULL if no probabilities were generated.
    LLAMA_API float * llama_get_backend_sampled_probs_ith(struct llama_context * ctx, int32_t i);

    // Get the backend sampled logits for the ith token.
    // Returns NULL if no logits were sampled.
    LLAMA_API float * llama_get_backend_sampled_logits_ith(struct llama_context * ctx, int32_t i);

    // Get the backend sampled token ids associated with the sampled logits for the ith token.
    // Returns NULL if no logits were sampled.
    LLAMA_API llama_token * llama_get_backend_sampled_token_ids_ith(struct llama_context * ctx, int32_t i);

    // Get the number of backend sampled logits for the ith token.
    LLAMA_API uint32_t llama_get_backend_sampled_logits_count_ith(struct llama_context * ctx, int32_t i);

    // Get the number of backend sampled probabilities for the ith token.
    LLAMA_API uint32_t llama_get_backend_sampled_probs_count_ith(struct llama_context * ctx, int32_t i);

    //
    // Vocab
    //
@ -1135,6 +1174,22 @@ extern "C" {
|
|||
struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
|
||||
void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL
|
||||
|
||||
void (*apply_ggml)( struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct llama_sampler_ggml_data * ggml_data);
|
||||
|
||||
void (*accept_ggml)( struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_tensor * selected_token);
|
||||
|
||||
void (*set_input_ggml)(struct llama_sampler * smpl);
|
||||
|
||||
void (*init_ggml)(struct llama_sampler * smpl,
|
||||
ggml_backend_buffer_type_t buft);
|
||||
|
||||
|
||||
// TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph
|
||||
//void (*apply_ggml) (struct llama_sampler * smpl, ...);
|
||||
};
|
||||
|
|
@ -1144,6 +1199,8 @@ extern "C" {
|
|||
llama_sampler_context_t ctx;
|
||||
};
|
||||
|
||||
LLAMA_API void llama_set_backend_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
|
||||
|
||||
// mirror of llama_sampler_i:
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx);
|
||||
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
|
||||
|
|
@ -1153,6 +1210,18 @@ extern "C" {
|
|||
LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl);
|
||||
// important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
|
||||
LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl);
|
||||
LLAMA_API void llama_sampler_init_ggml(struct llama_sampler * smpl,
|
||||
ggml_backend_buffer_type_t buft);
|
||||
LLAMA_API void llama_sampler_set_input_ggml(struct llama_sampler * smpl);
|
||||
LLAMA_API void llama_sampler_apply_ggml( struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct llama_sampler_ggml_data * ggml_data);
|
||||
|
||||
LLAMA_API void llama_sampler_accept_ggml( struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_tensor * selected_token);
|
||||
|
||||
// llama_sampler_chain
|
||||
// a type of llama_sampler that can chain multiple samplers one after another
|
||||
|
|
@ -1166,6 +1235,7 @@ extern "C" {
|
|||
|
||||
// after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
|
||||
LLAMA_API struct llama_sampler * llama_sampler_chain_remove( struct llama_sampler * chain, int32_t i);
|
||||
LLAMA_API uint64_t llama_sampler_chain_get_version(const struct llama_sampler * chain);
|
||||
|
||||
// available samplers:
|
||||
|
||||
|
|
@ -1299,9 +1369,29 @@ extern "C" {
|
|||
//
|
||||
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab);
|
||||
|
||||
//
|
||||
// Backend samplers
|
||||
//
|
||||
|
||||
/// @details Greedy sampling on backend - always selects the token with the highest probability
|
||||
LLAMA_API struct llama_sampler * llama_sampler_backend_init_greedy(void);
|
||||
|
||||
/// @details Temperature scaling on backend - scales logits by 1/temperature
|
||||
LLAMA_API struct llama_sampler * llama_sampler_backend_init_temp(float temp);
|
||||
|
||||
/// @details Top-K filtering on backend - keeps only the k tokens with highest probabilities
|
||||
LLAMA_API struct llama_sampler * llama_sampler_backend_init_top_k(int32_t k);
|
||||
|
||||
/// @details Distribution sampling on backend - final sampling step that selects a token
|
||||
LLAMA_API struct llama_sampler * llama_sampler_backend_init_dist(uint32_t seed);
|
||||
|
||||
// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
|
||||
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
|
||||
|
||||
LLAMA_API struct llama_sampler * llama_sampler_backend_init_logit_bias(int32_t n_vocab,
|
||||
int32_t n_logit_bias,
|
||||
const llama_logit_bias * logit_bias);
|
||||
|
||||
/// @details Sample and accept a token from the idx-th output of the last evaluation
|
||||
//
|
||||
// Shorthand for:
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ add_library(llama
|
|||
llama-model.cpp
|
||||
llama-quant.cpp
|
||||
llama-sampling.cpp
|
||||
llama-backend-sampler.cpp
|
||||
llama-vocab.cpp
|
||||
unicode-data.cpp
|
||||
unicode.cpp
|
||||
|
|
|
|||
|
|
@ -0,0 +1,489 @@
|
|||
#include "llama.h"
|
||||
#include "ggml.h"
|
||||
#include <cstdio>
|
||||
#include <chrono>
|
||||
#include <random>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
static void llama_sampler_backend_greedy_apply_ggml(
|
||||
struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct llama_sampler_ggml_data * ggml_data) {
|
||||
GGML_UNUSED(gf);
|
||||
GGML_UNUSED(smpl);
|
||||
struct ggml_tensor * argmax_result = ggml_argmax(ctx, ggml_data->logits);
|
||||
ggml_set_name(argmax_result, "argmax_result");
|
||||
ggml_data->sampled_token = argmax_result;
|
||||
}
|
||||
|
||||
static const char * llama_sampler_backend_greedy_sampler_name(const struct llama_sampler *) {
|
||||
return "test-ggml";
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_backend_greedy_clone(const struct llama_sampler * smpl) {
|
||||
(void) smpl;
|
||||
return llama_sampler_backend_init_greedy();
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_backend_init_greedy() {
|
||||
static const llama_sampler_i iface = {
|
||||
/*.name =*/ llama_sampler_backend_greedy_sampler_name,
|
||||
/*.accept =*/ nullptr,
|
||||
/*.apply =*/ nullptr,
|
||||
/*.reset =*/ nullptr,
|
||||
/*.clone =*/ llama_sampler_backend_greedy_clone,
|
||||
/*.free =*/ nullptr,
|
||||
/*.apply_ggml =*/ llama_sampler_backend_greedy_apply_ggml,
|
||||
/*.accept_ggml =*/ nullptr,
|
||||
/*.set_input_ggml =*/ nullptr,
|
||||
/*.init_ggml =*/ nullptr,
|
||||
};
|
||||
|
||||
auto * sampler = new llama_sampler {
|
||||
/*.iface =*/ &iface,
|
||||
/*.ctx =*/ nullptr,
|
||||
};
|
||||
|
||||
return sampler;
|
||||
}
|
||||
|
||||
struct llama_sampler_backend_temp_ctx {
|
||||
float temp;
|
||||
};
|
||||
|
||||
|
||||
static void llama_sampler_backend_temp_apply_ggml(
|
||||
struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct llama_sampler_ggml_data * ggml_data) {
|
||||
|
||||
auto * ctx_data = (llama_sampler_backend_temp_ctx *) smpl->ctx;
|
||||
|
||||
if (ctx_data->temp <= 0.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
struct ggml_tensor * scaled = ggml_scale(ctx, ggml_data->logits, 1.0f / ctx_data->temp);
|
||||
ggml_set_name(scaled, "temp_scaled");
|
||||
|
||||
// Make sure the scaled tensor is contiguous for subsequent operations
|
||||
ggml_data->logits = ggml_cont(ctx, scaled);
|
||||
ggml_set_name(ggml_data->logits, "temp_scaled_logits");
|
||||
|
||||
ggml_build_forward_expand(gf, ggml_data->logits);
|
||||
}
|
||||
|
||||
static const char * llama_sampler_backend_temp_name(const struct llama_sampler *) {
|
||||
return "backend-temp";
|
||||
}
|
||||
|
||||
static void llama_sampler_backend_temp_free(struct llama_sampler * smpl) {
|
||||
auto * ctx_data = (llama_sampler_backend_temp_ctx *) smpl->ctx;
|
||||
delete ctx_data;
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_backend_temp_clone(const struct llama_sampler * smpl) {
|
||||
auto * ctx = (llama_sampler_backend_temp_ctx *) smpl->ctx;
|
||||
return llama_sampler_backend_init_temp(ctx->temp);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_backend_init_temp(float temp) {
|
||||
static const llama_sampler_i iface = {
|
||||
/*.name =*/ llama_sampler_backend_temp_name,
|
||||
/*.accept =*/ nullptr,
|
||||
/*.apply =*/ nullptr,
|
||||
/*.reset =*/ nullptr,
|
||||
/*.clone =*/ llama_sampler_backend_temp_clone,
|
||||
/*.free =*/ llama_sampler_backend_temp_free,
|
||||
/*.apply_ggml =*/ llama_sampler_backend_temp_apply_ggml,
|
||||
/*.accept_ggml =*/ nullptr,
|
||||
/*.set_input_ggml =*/ nullptr,
|
||||
/*.set_backend_context =*/ nullptr,
|
||||
};
|
||||
|
||||
auto * ctx_data = new llama_sampler_backend_temp_ctx {
|
||||
/*.temp =*/ temp,
|
||||
};
|
||||
|
||||
auto * sampler = new llama_sampler {
|
||||
/*.iface =*/ &iface,
|
||||
/*.ctx =*/ ctx_data,
|
||||
};
|
||||
|
||||
return sampler;
|
||||
}
|
||||
|
||||
|
||||
struct llama_sampler_backend_top_k_ctx {
|
||||
int32_t k;
|
||||
|
||||
// Only required for checking operation support and can be removed later.
|
||||
ggml_backend_dev_t device;
|
||||
};
|
||||
|
||||
static void llama_sampler_backend_top_k_init_ggml(
|
||||
struct llama_sampler * smpl,
|
||||
ggml_backend_buffer_type_t buft) {
|
||||
auto * ctx_data = (llama_sampler_backend_top_k_ctx *) smpl->ctx;
|
||||
ctx_data->device = ggml_backend_buft_get_device(buft);
|
||||
}
|
||||
|
||||
static void llama_sampler_backend_top_k_apply_ggml(
|
||||
struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct llama_sampler_ggml_data * ggml_data) {
|
||||
|
||||
auto * ctx_data = (llama_sampler_backend_top_k_ctx *) smpl->ctx;
|
||||
|
||||
struct ggml_tensor * top_k = ggml_top_k(ctx, ggml_data->logits, ctx_data->k);
|
||||
ggml_set_name(top_k, "top_k");
|
||||
|
||||
// top_k is a view of argsort - check if backend supports the underlying argsort operation
|
||||
// by checking the source tensor (which is the argsort result)
|
||||
if (ctx_data->device && top_k->src[0] && !ggml_backend_dev_supports_op(ctx_data->device, top_k->src[0])) {
|
||||
fprintf(stderr, "Warning: backend does not support argsort operation required for top-k sampling\n");
|
||||
fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
|
||||
}
|
||||
|
||||
ggml_data->filtered_ids = top_k;
|
||||
|
||||
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]);
|
||||
struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
|
||||
ggml_set_name(top_k_rows, "top_k_rows");
|
||||
|
||||
ggml_data->logits = ggml_reshape_1d(ctx, top_k_rows, ctx_data->k);
|
||||
ggml_build_forward_expand(gf, ggml_data->logits);
|
||||
}
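The reshape/get_rows pair above is effectively an index gather. A standalone CPU sketch with made-up values (illustrative only, not part of the patch):

    #include <cstdint>
    #include <vector>

    int main() {
        std::vector<float>   logits       = {0.5f, 2.0f, 1.5f, 0.1f, 3.0f}; // example logits
        std::vector<int32_t> filtered_ids = {4, 1, 2};                      // example indices from ggml_top_k (k = 3)

        std::vector<float> top_k_logits(filtered_ids.size());
        for (size_t i = 0; i < filtered_ids.size(); ++i) {
            top_k_logits[i] = logits[filtered_ids[i]]; // gather one logit per kept index
        }
        // top_k_logits == {3.0, 2.0, 1.5}; filtered_ids is kept as well so a later
        // dist sampler can map a selected position back to the original vocab id.
        return 0;
    }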
|
||||
|
||||
static const char * llama_sampler_backend_top_k_name(const struct llama_sampler *) {
|
||||
return "backend-top-k";
|
||||
}
|
||||
|
||||
static void llama_sampler_backend_top_k_free(struct llama_sampler * smpl) {
|
||||
auto * ctx_data = (llama_sampler_backend_top_k_ctx *) smpl->ctx;
|
||||
delete ctx_data;
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_backend_top_k_clone(const struct llama_sampler * smpl) {
|
||||
auto * ctx = (llama_sampler_backend_top_k_ctx *) smpl->ctx;
|
||||
return llama_sampler_backend_init_top_k(ctx->k);
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_backend_init_top_k(int32_t k) {
|
||||
static const llama_sampler_i iface = {
|
||||
/*.name =*/ llama_sampler_backend_top_k_name,
|
||||
/*.accept =*/ nullptr,
|
||||
/*.apply =*/ nullptr,
|
||||
/*.reset =*/ nullptr,
|
||||
/*.clone =*/ llama_sampler_backend_top_k_clone,
|
||||
/*.free =*/ llama_sampler_backend_top_k_free,
|
||||
/*.apply_ggml =*/ llama_sampler_backend_top_k_apply_ggml,
|
||||
/*.accept_ggml =*/ nullptr,
|
||||
/*.set_input_ggml =*/ nullptr,
|
||||
/*.init_ggml =*/ llama_sampler_backend_top_k_init_ggml,
|
||||
};
|
||||
|
||||
auto * ctx_data = new llama_sampler_backend_top_k_ctx {
|
||||
/*.k =*/ k,
|
||||
/*.device =*/ nullptr,
|
||||
};
|
||||
|
||||
auto * sampler = new llama_sampler {
|
||||
/*.iface =*/ &iface,
|
||||
/*.ctx =*/ ctx_data,
|
||||
};
|
||||
|
||||
return sampler;
|
||||
}
|
||||
|
||||
|
||||
static uint32_t get_rng_seed(uint32_t seed) {
|
||||
if (seed == LLAMA_DEFAULT_SEED) {
|
||||
// use system clock if std::random_device is not a true RNG
|
||||
static bool is_rd_prng = std::random_device().entropy() == 0;
|
||||
if (is_rd_prng) {
|
||||
return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
|
||||
}
|
||||
std::random_device rd;
|
||||
return rd();
|
||||
}
|
||||
return seed;
|
||||
}
|
||||
|
||||
struct llama_sampler_backend_dist_ctx {
|
||||
const uint32_t seed;
|
||||
uint32_t seed_cur;
|
||||
std::mt19937 rng;
|
||||
|
||||
struct ggml_tensor * uniform;
|
||||
struct ggml_context * ctx;
|
||||
ggml_backend_buffer_t buffer;
|
||||
|
||||
// Only required for checking operation support and can be removed later.
|
||||
ggml_backend_dev_t device;
|
||||
};
|
||||
|
||||
static void llama_sampler_backend_dist_init_ggml(
|
||||
struct llama_sampler * smpl,
|
||||
ggml_backend_buffer_type_t buft) {
|
||||
|
||||
auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx;
|
||||
sctx->device = ggml_backend_buft_get_device(buft);
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ ggml_tensor_overhead(),
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
sctx->ctx = ggml_init(params);
|
||||
|
||||
// Create the uniform random scalar input tensor. This will be set by
|
||||
// llama_sampler_backend_dist_set_input_ggml after this graph is built.
|
||||
sctx->uniform = ggml_new_tensor_1d(sctx->ctx, GGML_TYPE_F32, 1);
|
||||
ggml_set_name(sctx->uniform, "uniform");
|
||||
ggml_set_input(sctx->uniform);
|
||||
ggml_set_output(sctx->uniform);
|
||||
|
||||
// Allocate all tensors from our context to the backend
|
||||
sctx->buffer = ggml_backend_alloc_ctx_tensors_from_buft(sctx->ctx, buft);
|
||||
}
|
||||
|
||||
static void llama_sampler_backend_dist_set_input_ggml(struct llama_sampler * smpl) {
|
||||
auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx;
|
||||
GGML_ASSERT(sctx->uniform != nullptr);
|
||||
|
||||
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
|
||||
const float rnd = dist(sctx->rng);
|
||||
ggml_backend_tensor_set(sctx->uniform, &rnd, 0, sizeof(float));
|
||||
}
|
||||
|
||||
static void llama_sampler_backend_dist_apply_ggml(
|
||||
struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct llama_sampler_ggml_data * ggml_data) {
|
||||
GGML_UNUSED(gf);
|
||||
auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx;
|
||||
|
||||
struct ggml_tensor * probs = ggml_soft_max(ctx, ggml_data->logits);
|
||||
ggml_set_name(probs, "dist_probs");
|
||||
|
||||
struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
|
||||
if (sctx->device && !ggml_backend_dev_supports_op(sctx->device, cumsum)) {
|
||||
fprintf(stderr, "Warning: backend does not support cumsum operation required for dist sampling\n");
|
||||
fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
|
||||
}
|
||||
ggml_set_name(cumsum, "cumsum");
|
||||
|
||||
// The uniform tensor holds a random value which is subtracted from the cumsum
// tensor (the uniform tensor is broadcast by ggml_sub).
|
||||
// Recall that each entry in cumsum is the cumulative probability up to that
|
||||
// index so values stay negative while the cumulative total is below the
|
||||
// random value, and become zero/positive once the threshold is crossed.
|
||||
struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->uniform);
|
||||
ggml_set_name(diff, "dist_cumsum");
|
||||
|
||||
// The ggml_step function produces a tensor where entries are 1 if the
|
||||
// corresponding entry in diff is > 0, and 0 otherwise. So all values up to
|
||||
// the index where the cumulative probability exceeds the random value are 0,
|
||||
// and all entries after that are 1.
|
||||
struct ggml_tensor * mask = ggml_step(ctx, diff);
|
||||
ggml_set_name(mask, "dist_mask");
|
||||
|
||||
// Taking the sum of the mask gives us the number of entries past the threshold,
// from which the index we are interested in can be recovered.
|
||||
struct ggml_tensor * idxf = ggml_sum(ctx, mask);
|
||||
ggml_set_name(idxf, "dist_index_f32");
|
||||
|
||||
// Use ggml_scale_bias to scale the index value by -1 and then add the size
|
||||
// of the mask to that value so we get the correct index ((-1 * idxf) + n).
|
||||
struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
|
||||
ggml_set_name(idx, "dist_index_i32");
|
||||
|
||||
// Map back to original vocab ids if a filtered id tensor is available.
|
||||
struct ggml_tensor * sampled_token = idx;
|
||||
if (ggml_data->filtered_ids != nullptr) {
|
||||
struct ggml_tensor * filtered_ids = ggml_data->filtered_ids;
|
||||
struct ggml_tensor * filtered_ids_reshaped = ggml_view_2d(ctx, filtered_ids, 1, ggml_nelements(filtered_ids),
|
||||
ggml_type_size(filtered_ids->type), 0);
|
||||
|
||||
sampled_token = ggml_get_rows(ctx, filtered_ids_reshaped, idx);
|
||||
ggml_set_name(sampled_token, "dist_sampled_token");
|
||||
}
|
||||
|
||||
ggml_set_output(sampled_token);
|
||||
ggml_data->sampled_token = sampled_token;
|
||||
}
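The cumsum/step/sum construction above is an inverse-CDF lookup expressed with ggml ops. As a hedged illustration (standalone CPU code with made-up example values, not part of the patch), this is the arithmetic the graph performs:

    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<float> probs = {0.1f, 0.2f, 0.3f, 0.4f}; // example softmax output
        const float r = 0.55f;                               // the "uniform" scalar input

        // cumsum -> {0.1, 0.3, 0.6, 1.0}
        std::vector<float> cumsum(probs.size());
        float acc = 0.0f;
        for (size_t i = 0; i < probs.size(); ++i) { acc += probs[i]; cumsum[i] = acc; }

        // mask = step(cumsum - r) -> {0, 0, 1, 1}; sum(mask) = 2
        int sum = 0;
        for (float c : cumsum) { sum += (c - r > 0.0f) ? 1 : 0; }

        // idx = (-1 * sum) + n = 4 - 2 = 2: the first index whose cumulative
        // probability exceeds r, matching what ggml_scale_bias computes above.
        const int idx = (int) probs.size() - sum;
        printf("sampled index: %d\n", idx);
        return 0;
    }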
|
||||
|
||||
static const char * llama_sampler_backend_dist_name(const struct llama_sampler *) {
|
||||
return "backend-dist";
|
||||
}
|
||||
|
||||
static void llama_sampler_backend_dist_free(struct llama_sampler * smpl) {
|
||||
auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx;
|
||||
ggml_backend_buffer_free(sctx->buffer);
|
||||
ggml_free(sctx->ctx);
|
||||
delete sctx;
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_backend_dist_clone(const struct llama_sampler * smpl) {
|
||||
auto * sctx = (llama_sampler_backend_dist_ctx *) smpl->ctx;
|
||||
return llama_sampler_backend_init_dist(sctx->seed);
|
||||
}
|
||||
|
||||
|
||||
struct llama_sampler * llama_sampler_backend_init_dist(uint32_t seed) {
|
||||
static const llama_sampler_i iface = {
|
||||
/*.name =*/ llama_sampler_backend_dist_name,
|
||||
/*.accept =*/ nullptr,
|
||||
/*.apply =*/ nullptr,
|
||||
/*.reset =*/ nullptr,
|
||||
/*.clone =*/ llama_sampler_backend_dist_clone,
|
||||
/*.free =*/ llama_sampler_backend_dist_free,
|
||||
/*.apply_ggml =*/ llama_sampler_backend_dist_apply_ggml,
|
||||
/*.accept_ggml =*/ nullptr,
|
||||
/*.set_input_ggml =*/ llama_sampler_backend_dist_set_input_ggml,
|
||||
/*.init_ggml =*/ llama_sampler_backend_dist_init_ggml,
|
||||
};
|
||||
|
||||
auto seed_cur = get_rng_seed(seed);
|
||||
auto * ctx_data = new llama_sampler_backend_dist_ctx {
|
||||
/*.seed =*/ seed,
|
||||
/*.seed_cur =*/ seed_cur,
|
||||
/*.rng =*/ std::mt19937(seed_cur),
|
||||
/*.uniform =*/ nullptr,
|
||||
/*.ctx =*/ nullptr,
|
||||
/*.buffer =*/ nullptr,
|
||||
/*.device =*/ nullptr,
|
||||
};
|
||||
|
||||
auto * sampler = new llama_sampler {
|
||||
/*.iface =*/ &iface,
|
||||
/*.ctx =*/ ctx_data,
|
||||
};
|
||||
|
||||
return sampler;
|
||||
}
|
||||
|
||||
struct llama_sampler_backend_logit_bias_ctx {
|
||||
const int32_t n_vocab;
|
||||
|
||||
const std::vector<llama_logit_bias> logit_bias;
|
||||
|
||||
struct ggml_tensor * logit_bias_t;
|
||||
struct ggml_context * ctx;
|
||||
ggml_backend_buffer_t buffer;
|
||||
};
|
||||
|
||||
static void llama_sampler_backend_logit_bias_init_ggml(
|
||||
struct llama_sampler * smpl,
|
||||
ggml_backend_buffer_type_t buft) {
|
||||
auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx;
|
||||
if (sctx->logit_bias.empty()) {
|
||||
return;
|
||||
}
|
||||
ggml_init_params params = {
|
||||
/*.mem_size =*/ ggml_tensor_overhead() * sctx->n_vocab * sizeof(float),
|
||||
/*.mem_buffer =*/ nullptr,
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
sctx->ctx = ggml_init(params);
|
||||
|
||||
struct ggml_tensor * logit_bias = ggml_new_tensor_1d(sctx->ctx, GGML_TYPE_F32, sctx->n_vocab);
|
||||
sctx->logit_bias_t = logit_bias;
|
||||
ggml_set_name(sctx->logit_bias_t, "logit_bias");
|
||||
ggml_set_input(sctx->logit_bias_t);
|
||||
ggml_set_output(sctx->logit_bias_t);
|
||||
|
||||
// Allocate all tensors from our context to the backend
|
||||
sctx->buffer = ggml_backend_alloc_ctx_tensors_from_buft(sctx->ctx, buft);
|
||||
}
|
||||
|
||||
static void llama_sampler_backend_logit_bias_set_input_ggml(struct llama_sampler * smpl) {
|
||||
auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx;
|
||||
if (sctx->logit_bias.empty()) {
|
||||
return;
|
||||
}
|
||||
GGML_ASSERT(sctx->logit_bias_t != nullptr);
|
||||
|
||||
// Create a sparse logit_bias vector from the logit_bias entries.
|
||||
std::vector<float> logit_bias_sparse(sctx->n_vocab, 0.0f);
|
||||
for (const auto & lb : sctx->logit_bias) {
|
||||
GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab);
|
||||
logit_bias_sparse[lb.token] = lb.bias;
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(sctx->logit_bias_t, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->logit_bias_t));
|
||||
}
|
||||
|
||||
static void llama_sampler_backend_logit_bias_apply_ggml(
|
||||
struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct llama_sampler_ggml_data * ggml_data) {
|
||||
GGML_UNUSED(gf);
|
||||
GGML_UNUSED(ctx);
|
||||
|
||||
auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx;
|
||||
if (sctx->logit_bias_t == nullptr) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Add the sparse logit bias vector to the logits
|
||||
struct ggml_tensor * logit_biased = ggml_add_inplace(sctx->ctx, ggml_data->logits, sctx->logit_bias_t);
|
||||
ggml_build_forward_expand(gf, logit_biased);
|
||||
}
|
||||
|
||||
static const char * llama_sampler_backend_logit_bias_name(const struct llama_sampler *) {
|
||||
return "backend-logit_bias";
|
||||
}
|
||||
|
||||
static void llama_sampler_backend_logit_bias_free(struct llama_sampler * smpl) {
|
||||
auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx;
|
||||
ggml_backend_buffer_free(sctx->buffer);
|
||||
ggml_free(sctx->ctx);
|
||||
delete sctx;
|
||||
}
|
||||
|
||||
static struct llama_sampler * llama_sampler_backend_logit_bias_clone(const struct llama_sampler * smpl) {
|
||||
auto * sctx = (llama_sampler_backend_logit_bias_ctx *) smpl->ctx;
|
||||
return llama_sampler_backend_init_logit_bias(sctx->n_vocab,
|
||||
sctx->logit_bias.size(),
|
||||
sctx->logit_bias.data());
|
||||
}
|
||||
|
||||
|
||||
struct llama_sampler * llama_sampler_backend_init_logit_bias(int32_t n_vocab,
|
||||
int32_t n_logit_bias,
|
||||
const llama_logit_bias * logit_bias) {
|
||||
static const llama_sampler_i iface = {
|
||||
/*.name =*/ llama_sampler_backend_logit_bias_name,
|
||||
/*.accept =*/ nullptr,
|
||||
/*.apply =*/ nullptr,
|
||||
/*.reset =*/ nullptr,
|
||||
/*.clone =*/ llama_sampler_backend_logit_bias_clone,
|
||||
/*.free =*/ llama_sampler_backend_logit_bias_free,
|
||||
/*.apply_ggml =*/ llama_sampler_backend_logit_bias_apply_ggml,
|
||||
/*.accept_ggml =*/ nullptr,
|
||||
/*.set_input_ggml =*/ llama_sampler_backend_logit_bias_set_input_ggml,
|
||||
/*.init_ggml =*/ llama_sampler_backend_logit_bias_init_ggml,
|
||||
};
|
||||
|
||||
auto * ctx_data = new llama_sampler_backend_logit_bias_ctx {
|
||||
/*.n_vocab =*/ n_vocab,
|
||||
/*.logit_bias =*/ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
|
||||
/*.logit_bias_t =*/ nullptr,
|
||||
/*.ctx =*/ nullptr,
|
||||
/*.buffer =*/ nullptr,
|
||||
};
|
||||
|
||||
auto * sampler = new llama_sampler {
|
||||
/*.iface =*/ &iface,
|
||||
/*.ctx =*/ ctx_data,
|
||||
};
|
||||
|
||||
return sampler;
|
||||
}
|
||||
|
|
@ -58,6 +58,16 @@ llama_context::llama_context(
|
|||
cparams.cb_eval = params.cb_eval;
|
||||
cparams.cb_eval_user_data = params.cb_eval_user_data;
|
||||
|
||||
// backend samplers
|
||||
if (params.samplers != nullptr && params.n_samplers > 0) {
|
||||
samplers.reserve(params.n_samplers);
|
||||
|
||||
for (size_t i = 0; i < params.n_samplers; ++i) {
|
||||
const auto & config = params.samplers[i];
|
||||
samplers[config.seq_id] = config.sampler;
|
||||
}
|
||||
}
|
||||
|
||||
auto rope_scaling_type = params.rope_scaling_type;
|
||||
if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) {
|
||||
rope_scaling_type = hparams.rope_scaling_type_train;
|
||||
|
|
@ -424,6 +434,10 @@ llama_context::llama_context(
|
|||
|
||||
llama_context::~llama_context() {
|
||||
ggml_opt_free(opt_ctx);
|
||||
// TODO: perhaps use a smart pointer for samplers
|
||||
for (auto const& [seq_id, sampler] : samplers) {
|
||||
llama_sampler_free(sampler);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_context::synchronize() {
|
||||
|
|
@ -610,6 +624,10 @@ float * llama_context::get_embeddings() {
|
|||
return embd;
|
||||
}
|
||||
|
||||
llama_token * llama_context::get_backend_sampled_tokens() {
|
||||
return sampled_tokens;
|
||||
}
|
||||
|
||||
float * llama_context::get_embeddings_ith(int32_t i) {
|
||||
int64_t j = -1;
|
||||
|
||||
|
|
@ -659,6 +677,98 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) {
|
|||
return it->second.data();
|
||||
}
|
||||
|
||||
llama_token llama_context::get_backend_sampled_token_ith(int32_t idx) {
|
||||
// Handle special case where idx == -1 (single sequence exists) which is
|
||||
// a valid index when using common_sampler_sample.
|
||||
if (idx == -1) {
|
||||
if (sampled_tokens_map.size() == 1) {
|
||||
auto it = sampled_tokens_map.begin();
|
||||
return it->second;
|
||||
}
|
||||
return LLAMA_TOKEN_NULL;
|
||||
}
|
||||
|
||||
auto it = sampled_tokens_map.find(idx);
|
||||
if (it == sampled_tokens_map.end()) {
|
||||
return LLAMA_TOKEN_NULL;
|
||||
}
|
||||
|
||||
return it->second;
|
||||
}
|
||||
|
||||
float * llama_context::get_backend_sampled_probs_ith(int32_t idx) {
|
||||
if (idx == -1) {
|
||||
if (sampled_probs_map.size() == 1) {
|
||||
return sampled_probs_map.begin()->second.data();
|
||||
}
|
||||
}
|
||||
|
||||
auto it = sampled_probs_map.find(idx);
|
||||
if (it == sampled_probs_map.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return it->second.data();
|
||||
}
|
||||
|
||||
float * llama_context::get_backend_sampled_logits_ith(int32_t idx) {
|
||||
if (idx == -1) {
|
||||
if (sampled_logits_map.size() == 1) {
|
||||
return sampled_logits_map.begin()->second.data();
|
||||
}
|
||||
}
|
||||
auto it = sampled_logits_map.find(idx);
|
||||
if (it == sampled_logits_map.end()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return it->second.data();
|
||||
}
|
||||
|
||||
const llama_token * llama_context::get_backend_sampled_token_ids_ith(int32_t idx) {
|
||||
if (idx == -1) {
|
||||
if (sampled_token_ids_map.size() == 1) {
|
||||
return sampled_token_ids_map.begin()->second.data();
|
||||
}
|
||||
}
|
||||
auto it = sampled_token_ids_map.find(idx);
|
||||
if (it == sampled_token_ids_map.end() || it->second.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return it->second.data();
|
||||
}
|
||||
|
||||
size_t llama_context::get_backend_sampled_logits_count(int32_t idx) const {
|
||||
if (idx == -1) {
|
||||
if (sampled_logits_map.size() == 1) {
|
||||
return sampled_logits_map.begin()->second.size();
|
||||
}
|
||||
}
|
||||
auto it = sampled_logits_map.find(idx);
|
||||
if (it == sampled_logits_map.end()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return it->second.size();
|
||||
}
|
||||
|
||||
size_t llama_context::get_backend_sampled_probs_count(int32_t idx) const {
|
||||
if (idx == -1) {
|
||||
if (sampled_probs_map.size() == 1) {
|
||||
return sampled_probs_map.begin()->second.size();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
auto it = sampled_probs_map.find(idx);
|
||||
if (it == sampled_probs_map.end()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return it->second.size();
|
||||
}
|
||||
|
||||
void llama_context::attach_threadpool(
|
||||
ggml_threadpool_t threadpool,
|
||||
ggml_threadpool_t threadpool_batch) {
|
||||
|
|
@ -715,6 +825,37 @@ void llama_context::set_warmup(bool value) {
|
|||
cparams.warmup = value;
|
||||
}
|
||||
|
||||
void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
|
||||
LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
|
||||
|
||||
auto it = samplers.find(seq_id);
|
||||
if (it != samplers.end()) {
|
||||
// If the sampler to be set is the same as the one already set, do nothing.
|
||||
if (it->second == sampler) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_sampler_free(it->second);
|
||||
|
||||
// If sampler is nullptr, we remove the sampler chain for this seq_id.
|
||||
if (sampler == nullptr) {
|
||||
samplers.erase(it);
|
||||
return;
|
||||
}
|
||||
|
||||
// Otherwise, we replace the existing sampler with the new one.
|
||||
it->second = sampler;
|
||||
return;
|
||||
}
|
||||
|
||||
// If there is no sampler for this seq_id and the caller provides a non-null
|
||||
// sampler, we set it.
|
||||
if (sampler != nullptr) {
|
||||
samplers[seq_id] = sampler;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_context::set_adapter_lora(
|
||||
llama_adapter_lora * adapter,
|
||||
float scale) {
|
||||
|
|
@ -1029,6 +1170,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
|
||||
// TODO: this clear of the buffer can easily be forgotten - need something better
|
||||
embd_seq.clear();
|
||||
sampled_probs_map.clear();
|
||||
sampled_logits_map.clear();
|
||||
sampled_tokens_map.clear();
|
||||
sampled_token_ids_map.clear();
|
||||
output_swaps.clear();
|
||||
|
||||
bool did_optimize = false;
|
||||
|
|
@ -1088,6 +1233,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
};
|
||||
|
||||
int64_t n_outputs_prev = 0;
|
||||
// This flag indicates whether a backend sampler has actually sampled a specific
// token, or if it has produced probabilities. If true, we can skip
// the normal copying of logits and embeddings.
|
||||
bool backend_has_sampled = false;
|
||||
|
||||
do {
|
||||
const auto & ubatch = mctx->get_ubatch();
|
||||
|
|
@ -1147,80 +1296,131 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||
//}
|
||||
|
||||
auto * t_logits = res->get_logits();
|
||||
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
||||
|
||||
if (t_embd && res->get_embd_pooled()) {
|
||||
t_embd = res->get_embd_pooled();
|
||||
}
|
||||
|
||||
// extract logits
|
||||
if (t_logits && n_outputs > 0) {
|
||||
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
|
||||
GGML_ASSERT(backend_res != nullptr);
|
||||
GGML_ASSERT(logits != nullptr);
|
||||
|
||||
float * logits_out = logits + n_outputs_prev*n_vocab;
|
||||
|
||||
if (n_outputs) {
|
||||
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
||||
GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
|
||||
ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
|
||||
std::unordered_map<llama_seq_id, int32_t> seq_to_idx;
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
||||
if (ubatch.output[i]) {
|
||||
llama_seq_id seq_id = ubatch.seq_id[i][0];
|
||||
seq_to_idx[seq_id] = i;
|
||||
}
|
||||
}
|
||||
|
||||
// extract embeddings
|
||||
if (t_embd && n_outputs > 0) {
|
||||
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
||||
GGML_ASSERT(backend_embd != nullptr);
|
||||
// extract sampled tokens
|
||||
for (const auto & [seq_id, t_token] : res->t_sampled_tokens) {
|
||||
auto idx_it = seq_to_idx.find(seq_id);
|
||||
GGML_ASSERT(idx_it != seq_to_idx.end());
|
||||
const int32_t idx = idx_it->second;
|
||||
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_token);
|
||||
ggml_backend_tensor_get_async(backend, t_token, &sampled_tokens_map[idx], 0, sizeof(llama_token));
|
||||
}
|
||||
|
||||
switch (cparams.pooling_type) {
|
||||
case LLAMA_POOLING_TYPE_NONE:
|
||||
{
|
||||
// extract token embeddings
|
||||
GGML_ASSERT(embd != nullptr);
|
||||
float * embd_out = embd + n_outputs_prev*n_embd;
|
||||
for (const auto & [seq_id, t_ids] : res->t_sampled_token_ids) {
|
||||
auto idx_it = seq_to_idx.find(seq_id);
|
||||
GGML_ASSERT(idx_it != seq_to_idx.end());
|
||||
const int32_t idx = idx_it->second;
|
||||
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_ids);
|
||||
sampled_token_ids_map[idx].resize(ggml_nelements(t_ids));
|
||||
ggml_backend_tensor_get_async(backend, t_ids, sampled_token_ids_map[idx].data(), 0, ggml_nbytes(t_ids));
|
||||
}
|
||||
|
||||
if (n_outputs) {
|
||||
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
||||
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
|
||||
if (res->t_sampled_tokens.empty()) {
|
||||
for (const auto & [seq_id, t_logits] : res->t_sampled_logits) {
|
||||
auto idx_it = seq_to_idx.find(seq_id);
|
||||
GGML_ASSERT(idx_it != seq_to_idx.end());
|
||||
const int32_t idx = idx_it->second;
|
||||
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
|
||||
sampled_logits_map[idx].resize(ggml_nelements(t_logits));
|
||||
ggml_backend_tensor_get_async(backend, t_logits, sampled_logits_map[idx].data(), 0, ggml_nbytes(t_logits));
|
||||
}
|
||||
|
||||
// extract sampled probabilities
|
||||
for (const auto & [seq_id, t_probs] : res->t_sampled_probs) {
|
||||
auto idx_it = seq_to_idx.find(seq_id);
|
||||
GGML_ASSERT(idx_it != seq_to_idx.end());
|
||||
const int32_t idx = idx_it->second;
|
||||
ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t_probs);
|
||||
sampled_probs_map[idx].resize(ggml_nelements(t_probs));
|
||||
ggml_backend_tensor_get_async(backend, t_probs, sampled_probs_map[idx].data(), 0, ggml_nbytes(t_probs));
|
||||
}
|
||||
}
|
||||
|
||||
backend_has_sampled = !res->t_sampled_tokens.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
|
||||
|
||||
if (!backend_has_sampled) {
|
||||
auto * t_logits = res->get_logits();
|
||||
auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
|
||||
|
||||
if (t_embd && res->get_embd_pooled()) {
|
||||
t_embd = res->get_embd_pooled();
|
||||
}
|
||||
|
||||
// extract logits
|
||||
if (t_logits && n_outputs > 0) {
|
||||
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
|
||||
GGML_ASSERT(backend_res != nullptr);
|
||||
GGML_ASSERT(logits != nullptr);
|
||||
|
||||
float * logits_out = logits + n_outputs_prev*n_vocab;
|
||||
|
||||
if (n_outputs) {
|
||||
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
||||
GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
|
||||
ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
|
||||
}
|
||||
}
|
||||
|
||||
// extract embeddings
|
||||
if (t_embd && n_outputs > 0) {
|
||||
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
||||
GGML_ASSERT(backend_embd != nullptr);
|
||||
|
||||
switch (cparams.pooling_type) {
|
||||
case LLAMA_POOLING_TYPE_NONE:
|
||||
{
|
||||
// extract token embeddings
|
||||
GGML_ASSERT(embd != nullptr);
|
||||
float * embd_out = embd + n_outputs_prev*n_embd;
|
||||
|
||||
if (n_outputs) {
|
||||
GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
|
||||
GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_MEAN:
|
||||
case LLAMA_POOLING_TYPE_CLS:
|
||||
case LLAMA_POOLING_TYPE_LAST:
|
||||
{
|
||||
// extract sequence embeddings (cleared before processing each batch)
|
||||
auto & embd_seq_out = embd_seq;
|
||||
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
|
||||
const int32_t seq_idx = ubatch.seq_idx[seq_id];
|
||||
|
||||
embd_seq_out[seq_id].resize(n_embd);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_RANK:
|
||||
{
|
||||
// extract the rerank score - n_cls_out floats per sequence
|
||||
auto & embd_seq_out = embd_seq;
|
||||
|
||||
const uint32_t n_cls_out = hparams.n_cls_out;
|
||||
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
|
||||
const int32_t seq_idx = ubatch.seq_idx[seq_id];
|
||||
|
||||
embd_seq_out[seq_id].resize(n_cls_out);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||
{
|
||||
GGML_ABORT("unknown pooling type");
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_MEAN:
|
||||
case LLAMA_POOLING_TYPE_CLS:
|
||||
case LLAMA_POOLING_TYPE_LAST:
|
||||
{
|
||||
// extract sequence embeddings (cleared before processing each batch)
|
||||
auto & embd_seq_out = embd_seq;
|
||||
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
|
||||
const int32_t seq_idx = ubatch.seq_idx[seq_id];
|
||||
|
||||
embd_seq_out[seq_id].resize(n_embd);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_RANK:
|
||||
{
|
||||
// extract the rerank score - n_cls_out floats per sequence
|
||||
auto & embd_seq_out = embd_seq;
|
||||
|
||||
const uint32_t n_cls_out = hparams.n_cls_out;
|
||||
|
||||
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
||||
const llama_seq_id seq_id = ubatch.seq_id_unq[s];
|
||||
const int32_t seq_idx = ubatch.seq_idx[seq_id];
|
||||
|
||||
embd_seq_out[seq_id].resize(n_cls_out);
|
||||
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
|
||||
}
|
||||
} break;
|
||||
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
||||
{
|
||||
GGML_ABORT("unknown pooling type");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1231,7 +1431,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
n_outputs = n_outputs_all;
|
||||
|
||||
// set output mappings
|
||||
if (n_outputs > 0) {
|
||||
if (n_outputs > 0 && !backend_has_sampled) {
|
||||
bool sorted_output = true;
|
||||
|
||||
auto & out_ids = balloc->get_out_ids();
|
||||
|
|
@ -1345,9 +1545,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
|||
}
|
||||
|
||||
float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get());
|
||||
llama_token * s_output_base = (llama_token *) ggml_backend_buffer_get_base(buf_output.get());
|
||||
|
||||
logits = has_logits ? output_base : nullptr;
|
||||
embd = has_embd ? output_base + logits_size : nullptr;
|
||||
logits = has_logits ? output_base : nullptr;
|
||||
embd = has_embd ? output_base + logits_size : nullptr;
|
||||
sampled_tokens = !samplers.empty() ? s_output_base : nullptr;
|
||||
sampled_probs = !samplers.empty() ? embd : nullptr;
|
||||
|
||||
// set all ids as invalid (negative)
|
||||
std::fill(output_ids.begin(), output_ids.end(), -1);
|
||||
|
|
@ -1456,6 +1659,7 @@ llm_graph_params llama_context::graph_params(
|
|||
/*.loras =*/ &loras,
|
||||
/*.mctx =*/ mctx,
|
||||
/*.cross =*/ &cross,
|
||||
/*.samplers =*/ samplers,
|
||||
/*.n_outputs =*/ n_outputs,
|
||||
/*.cb =*/ graph_get_cb(),
|
||||
/*.res =*/ res,
|
||||
|
|
@ -2319,6 +2523,8 @@ llama_context_params llama_context_default_params() {
|
|||
/*.op_offload =*/ true,
|
||||
/*.swa_full =*/ true,
|
||||
/*.kv_unified =*/ false,
|
||||
/*.sampler =*/ nullptr,
|
||||
/*.n_sampler =*/ 0,
|
||||
};
|
||||
|
||||
return result;
|
||||
|
|
@ -2478,6 +2684,13 @@ float * llama_get_logits(llama_context * ctx) {
|
|||
float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
|
||||
ctx->synchronize();
|
||||
|
||||
if (ctx->get_backend_sampled_token_ith(i) != LLAMA_TOKEN_NULL) {
|
||||
return nullptr;
|
||||
}
|
||||
if (ctx->get_backend_sampled_probs_ith(i) != nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return ctx->get_logits_ith(i);
|
||||
}
|
||||
|
||||
|
|
@ -2499,6 +2712,46 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) {
|
|||
return ctx->get_embeddings_seq(seq_id);
|
||||
}
|
||||
|
||||
void llama_set_backend_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * sampler) {
|
||||
ctx->set_backend_sampler(seq_id, sampler);
|
||||
}
|
||||
|
||||
llama_token llama_get_backend_sampled_token_ith(llama_context * ctx, int32_t i) {
|
||||
ctx->synchronize();
|
||||
|
||||
return ctx->get_backend_sampled_token_ith(i);
|
||||
}
|
||||
|
||||
float * llama_get_backend_sampled_probs_ith(llama_context * ctx, int32_t i) {
|
||||
ctx->synchronize();
|
||||
|
||||
return ctx->get_backend_sampled_probs_ith(i);
|
||||
}
|
||||
|
||||
float * llama_get_backend_sampled_logits_ith(llama_context * ctx, int32_t i) {
|
||||
ctx->synchronize();
|
||||
|
||||
return ctx->get_backend_sampled_logits_ith(i);
|
||||
}
|
||||
|
||||
llama_token * llama_get_backend_sampled_token_ids_ith(llama_context * ctx, int32_t i) {
|
||||
ctx->synchronize();
|
||||
|
||||
return const_cast<llama_token *>(ctx->get_backend_sampled_token_ids_ith(i));
|
||||
}
|
||||
|
||||
uint32_t llama_get_backend_sampled_logits_count_ith(llama_context * ctx, int32_t i) {
|
||||
ctx->synchronize();
|
||||
|
||||
return static_cast<uint32_t>(ctx->get_backend_sampled_logits_count(i));
|
||||
}
|
||||
|
||||
uint32_t llama_get_backend_sampled_probs_count_ith(llama_context * ctx, int32_t i) {
|
||||
ctx->synchronize();
|
||||
|
||||
return static_cast<uint32_t>(ctx->get_backend_sampled_probs_count(i));
|
||||
}
|
||||
|
||||
// llama adapter API
|
||||
|
||||
int32_t llama_set_adapter_lora(
|
||||
|
|
|
|||
|
|
@ -66,6 +66,16 @@ struct llama_context {
|
|||
float * get_embeddings_ith(int32_t i);
|
||||
float * get_embeddings_seq(llama_seq_id seq_id);
|
||||
|
||||
llama_token * get_backend_sampled_tokens();
|
||||
llama_token get_backend_sampled_token_ith(int32_t idx);
|
||||
|
||||
float * get_backend_sampled_logits_ith(int32_t idx);
|
||||
const llama_token * get_backend_sampled_token_ids_ith(int32_t idx);
|
||||
size_t get_backend_sampled_logits_count(int32_t idx) const;
|
||||
|
||||
float * get_backend_sampled_probs_ith(int32_t idx);
|
||||
size_t get_backend_sampled_probs_count(int32_t idx) const;
|
||||
|
||||
void attach_threadpool(
|
||||
ggml_threadpool_t threadpool,
|
||||
ggml_threadpool_t threadpool_batch);
|
||||
|
|
@ -208,6 +218,8 @@ public:
|
|||
// reserve a graph with a dummy ubatch of the specified size
|
||||
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
|
||||
|
||||
void set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler);
|
||||
|
||||
private:
|
||||
llm_graph_params graph_params(
|
||||
llm_graph_result * res,
|
||||
|
|
@ -242,6 +254,16 @@ private:
|
|||
size_t logits_size = 0; // capacity (of floats) for logits
|
||||
float * logits = nullptr;
|
||||
|
||||
std::unordered_map<llama_seq_id, llama_sampler*> samplers;
|
||||
llama_token * sampled_tokens = nullptr;
|
||||
std::unordered_map<int32_t, llama_token> sampled_tokens_map;
|
||||
|
||||
float * sampled_probs = nullptr;
|
||||
std::unordered_map<int32_t, std::vector<float>> sampled_probs_map;
|
||||
|
||||
std::unordered_map<int32_t, std::vector<float>> sampled_logits_map;
|
||||
std::unordered_map<int32_t, std::vector<llama_token>> sampled_token_ids_map;
|
||||
|
||||
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
||||
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
||||
size_t embd_size = 0; // capacity (of floats) for embeddings
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
#include "llama-impl.h"
|
||||
#include "llama-batch.h"
|
||||
#include "llama-cparams.h"
|
||||
#include "llama-model.h"
|
||||
|
||||
#include "llama-kv-cache.h"
|
||||
#include "llama-kv-cache-iswa.h"
|
||||
|
|
@ -462,6 +463,28 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
|||
inp_rs->set_input(ubatch);
|
||||
}
|
||||
|
||||
void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
|
||||
GGML_UNUSED(ubatch);
|
||||
for (const auto & [seq_id, sampler] : samplers) {
|
||||
if (sampler->iface->set_input_ggml) {
|
||||
sampler->iface->set_input_ggml(sampler);
|
||||
}
|
||||
}
|
||||
}
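As a hedged illustration of what the init_ggml/set_input_ggml pair invoked here could look like in a sampler implementation (not part of this commit): a host-side table is allocated once in the output device's buffer type and refreshed before each graph execution. The my_bias_* names are hypothetical, and the sketch assumes the internal llama_sampler definition (iface/ctx) is visible, as it is inside llama-sampling.cpp.

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

#include <vector>

// hypothetical sampler state: host values mirrored into a backend buffer
struct my_bias_ctx {
    std::vector<float>    bias_host;           // values prepared on the CPU
    ggml_context        * ctx_meta = nullptr;  // holds the tensor metadata
    ggml_backend_buffer_t buf      = nullptr;  // device allocation
    ggml_tensor         * bias_dev = nullptr;  // F32 [bias_host.size()]
};

// init_ggml: called once from build_sampling with the output device's buffer type
static void my_bias_init_ggml(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft) {
    auto * sctx = (my_bias_ctx *) smpl->ctx;
    ggml_init_params ip = { ggml_tensor_overhead(), nullptr, /*no_alloc =*/ true };
    sctx->ctx_meta = ggml_init(ip);
    sctx->bias_dev = ggml_new_tensor_1d(sctx->ctx_meta, GGML_TYPE_F32, (int64_t) sctx->bias_host.size());
    sctx->buf      = ggml_backend_alloc_ctx_tensors_from_buft(sctx->ctx_meta, buft);
    // a real sampler would release ctx_meta and buf in its free callback
}

// set_input_ggml: called from llm_graph_input_sampling::set_input before each
// graph execution to refresh the device copy of the host values
static void my_bias_set_input_ggml(struct llama_sampler * smpl) {
    auto * sctx = (my_bias_ctx *) smpl->ctx;
    if (sctx->bias_dev == nullptr) {
        return;
    }
    ggml_backend_tensor_set(sctx->bias_dev, sctx->bias_host.data(), 0,
                            sctx->bias_host.size() * sizeof(float));
}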
|
||||
|
||||
bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
|
||||
if (params.samplers.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
for (const auto & [seq_id, sampler] : params.samplers) {
|
||||
if (sampler_versions[seq_id] != llama_sampler_chain_get_version(sampler)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
//
|
||||
// llm_graph_result
|
||||
//
|
||||
|
|
@ -482,6 +505,10 @@ void llm_graph_result::reset() {
|
|||
t_logits = nullptr;
|
||||
t_embd = nullptr;
|
||||
t_embd_pooled = nullptr;
|
||||
t_sampled_tokens.clear();
|
||||
t_sampled_probs.clear();
|
||||
t_sampled_logits.clear();
|
||||
t_sampled_token_ids.clear();
|
||||
|
||||
params = {};
|
||||
|
||||
|
|
@ -587,6 +614,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
|||
loras (params.loras),
|
||||
mctx (params.mctx),
|
||||
cross (params.cross),
|
||||
samplers (params.samplers),
|
||||
cb_func (params.cb),
|
||||
res (params.res),
|
||||
ctx0 (res->get_ctx()),
|
||||
|
|
@ -2021,6 +2049,103 @@ void llm_graph_context::build_pooling(
|
|||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
void llm_graph_context::build_sampling(const llama_model & model, const llm_graph_params & params) const {
|
||||
GGML_UNUSED(params);
|
||||
if (samplers.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::unordered_map<llama_seq_id, int32_t> seq_to_logit_row;
|
||||
int32_t logit_row_idx = 0;
|
||||
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
||||
if (ubatch.output[i]) {
|
||||
llama_seq_id seq_id = ubatch.seq_id[i][0];
|
||||
seq_to_logit_row[seq_id] = logit_row_idx;
|
||||
logit_row_idx++;
|
||||
}
|
||||
}
|
||||
if (seq_to_logit_row.empty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// res->t_logits will contain logits for all tokens that specified that they want
|
||||
// logits calculated (logits=1 or output=1)
|
||||
ggml_tensor * logits_t = res->t_logits;
|
||||
GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
|
||||
|
||||
const int n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(&model));
|
||||
GGML_ASSERT(logits_t->ne[0] == n_vocab);
|
||||
|
||||
ggml_backend_dev_t device = model.dev_output();
|
||||
ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(device);
|
||||
|
||||
std::unordered_map<llama_seq_id, llama_sampler*> active_samplers;
|
||||
|
||||
for (const auto & [seq_id, sampler] : samplers) {
|
||||
// Only process samplers for sequences that are in the current batch
|
||||
auto it = seq_to_logit_row.find(seq_id);
|
||||
if (it == seq_to_logit_row.end()) {
|
||||
continue;
|
||||
}
|
||||
const int32_t row_idx = it->second;
|
||||
|
||||
// Allow a GPU sampler to create its input tensors by implementing init_ggml.
|
||||
if (sampler->iface->init_ggml != nullptr) {
|
||||
sampler->iface->init_ggml(sampler, buft);
|
||||
}
|
||||
|
||||
active_samplers[seq_id] = sampler;
|
||||
|
||||
ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, n_vocab, row_idx * logits_t->nb[1]);
|
||||
ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
|
||||
|
||||
struct llama_sampler_ggml_data ggml_data = {
|
||||
/*.logits =*/ logits_seq,
|
||||
/*.probs =*/ nullptr,
|
||||
/*.sampled_token =*/ nullptr,
|
||||
/*.filtered_ids =*/ nullptr,
|
||||
};
|
||||
|
||||
llama_sampler_apply_ggml(sampler, ctx0, gf, &ggml_data);
|
||||
|
||||
if (ggml_data.sampled_token != nullptr) {
|
||||
res->t_sampled_tokens[seq_id] = ggml_data.sampled_token;
|
||||
ggml_build_forward_expand(gf, ggml_data.sampled_token);
|
||||
}
|
||||
|
||||
if (ggml_data.probs != nullptr) {
|
||||
res->t_sampled_probs[seq_id] = ggml_data.probs;
|
||||
ggml_build_forward_expand(gf, ggml_data.probs);
|
||||
}
|
||||
|
||||
if (ggml_data.logits != logits_seq) {
|
||||
res->t_sampled_logits[seq_id] = ggml_data.logits;
|
||||
ggml_build_forward_expand(gf, res->t_sampled_logits[seq_id]);
|
||||
}
|
||||
|
||||
if (ggml_data.filtered_ids != nullptr) {
|
||||
res->t_sampled_token_ids[seq_id] = ggml_data.filtered_ids;
|
||||
ggml_build_forward_expand(gf, ggml_data.filtered_ids);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
|
||||
/*
|
||||
for (const auto & [seq_id, sampler] : samplers) {
|
||||
if (auto it = res->t_sampled_tokens.find(seq_id); it != res->t_sampled_tokens.end()) {
|
||||
ggml_tensor * selected_token = it->second;
|
||||
if (selected_token != nullptr) {
|
||||
llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
auto inp_sampling = std::make_unique<llm_graph_input_sampling>(n_vocab, false, active_samplers);
|
||||
res->add_input(std::move(inp_sampling));
|
||||
}
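To make the llama_sampler_ggml_data flow above concrete, here is a minimal sketch (not part of this commit) of what a backend sampler's apply_ggml callback could build for greedy selection. my_greedy_apply_ggml is a hypothetical name; the field names follow the llama_sampler_ggml_data struct filled in by build_sampling.

static void my_greedy_apply_ggml(
        struct llama_sampler           * smpl,
        struct ggml_context            * ctx,
        struct ggml_cgraph             * gf,
        struct llama_sampler_ggml_data * ggml_data) {
    GGML_UNUSED(smpl);
    GGML_UNUSED(gf);
    // ggml_data->logits is a 1d view of one sequence's row of logits (n_vocab floats);
    // argmax over it yields an I32 tensor holding the selected token id
    ggml_data->sampled_token = ggml_argmax(ctx, ggml_data->logits);
    // build_sampling() expands the graph with sampled_token and the context later
    // copies the single id back to host memory after the graph has been executed
}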
|
||||
|
||||
int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
|
||||
// TODO move to hparams if a T5 variant appears that uses a different value
|
||||
const int64_t max_distance = 128;
|
||||
|
|
|
|||
|
|
@ -383,6 +383,32 @@ public:
|
|||
const llama_memory_hybrid_context * mctx;
|
||||
};
|
||||
|
||||
class llm_graph_input_sampling : public llm_graph_input_i {
|
||||
public:
|
||||
llm_graph_input_sampling(int32_t n_vocab, bool sorted,
|
||||
std::unordered_map<llama_seq_id, llama_sampler*> samplers) :
|
||||
n_vocab(n_vocab), sorted_value(sorted), samplers(samplers) {
|
||||
|
||||
sampler_versions.reserve(samplers.size());
|
||||
for (const auto & [seq_id, sampler] : samplers) {
|
||||
sampler_versions[seq_id] = llama_sampler_chain_get_version(sampler);
|
||||
}
|
||||
}
|
||||
virtual ~llm_graph_input_sampling() = default;
|
||||
|
||||
void set_input(const llama_ubatch * ubatch) override;
|
||||
bool can_reuse(const llm_graph_params & params) override;
|
||||
|
||||
int32_t n_vocab;
|
||||
bool sorted_value;
|
||||
ggml_tensor * size = nullptr; // I32 [1]
|
||||
ggml_tensor * sorted = nullptr; // I32 [1]
|
||||
|
||||
// Track sampler chain version for reuse
|
||||
std::unordered_map<llama_seq_id, uint64_t> sampler_versions;
|
||||
std::unordered_map<llama_seq_id, llama_sampler*> samplers;
|
||||
};
|
||||
|
||||
//
|
||||
// llm_graph_result
|
||||
//
|
||||
|
|
@ -416,6 +442,23 @@ struct llm_graph_params {
|
|||
const llama_memory_context_i * mctx;
|
||||
const llama_cross * cross;
|
||||
|
||||
std::unordered_map<llama_seq_id, llama_sampler*> samplers;
|
||||
|
||||
static bool samplers_equal(
|
||||
const std::unordered_map<llama_seq_id, llama_sampler*> & lhs,
|
||||
const std::unordered_map<llama_seq_id, llama_sampler*> & rhs) {
|
||||
if (lhs.size() != rhs.size()) {
|
||||
return false;
|
||||
}
|
||||
for (const auto & [seq_id, sampler] : lhs) {
|
||||
auto it = rhs.find(seq_id);
|
||||
if (it == rhs.end() || it->second != sampler) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
uint32_t n_outputs;
|
||||
|
||||
llm_graph_cb cb;
|
||||
|
|
@ -463,7 +506,9 @@ struct llm_graph_params {
|
|||
cvec == other.cvec &&
|
||||
loras == other.loras &&
|
||||
cross == other.cross &&
|
||||
n_outputs == other.n_outputs;
|
||||
n_outputs == other.n_outputs &&
|
||||
samplers_equal(samplers, other.samplers);
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -504,6 +549,11 @@ public:
|
|||
ggml_tensor * t_embd = nullptr;
|
||||
ggml_tensor * t_embd_pooled = nullptr;
|
||||
|
||||
std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled_logits;
|
||||
std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled_token_ids;
|
||||
std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled_tokens;
|
||||
std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled_probs;
|
||||
|
||||
std::vector<llm_graph_input_ptr> inputs;
|
||||
|
||||
ggml_context_ptr ctx_compute;
|
||||
|
|
@ -579,6 +629,8 @@ struct llm_graph_context {
|
|||
const llama_memory_context_i * mctx;
|
||||
const llama_cross * cross;
|
||||
|
||||
std::unordered_map<llama_seq_id, llama_sampler*> samplers;
|
||||
|
||||
const llm_graph_cb & cb_func;
|
||||
|
||||
llm_graph_result * res;
|
||||
|
|
@ -819,6 +871,12 @@ struct llm_graph_context {
|
|||
ggml_tensor * cls_out,
|
||||
ggml_tensor * cls_out_b) const;
|
||||
|
||||
//
|
||||
// sampling (backend sampling)
|
||||
//
|
||||
|
||||
void build_sampling(const llama_model & model, const llm_graph_params & params) const;
|
||||
|
||||
//
|
||||
// dense (out)
|
||||
//
|
||||
|
|
|
|||
|
|
@ -7412,6 +7412,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
// add on pooling layer
|
||||
llm->build_pooling(cls, cls_b, cls_out, cls_out_b);
|
||||
|
||||
// add backend sampling layers (if any)
|
||||
llm->build_sampling(*this, params);
|
||||
|
||||
// if the gguf model was converted with --sentence-transformers-dense-modules
|
||||
// there will be two additional dense projection layers
|
||||
// dense linear projections are applied after pooling
|
||||
|
|
|
|||
|
|
@ -372,6 +372,39 @@ void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_ar
|
|||
smpl->iface->apply(smpl, cur_p);
|
||||
}
|
||||
|
||||
void llama_sampler_apply_ggml(
|
||||
struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct llama_sampler_ggml_data * ggml_data) {
|
||||
GGML_ASSERT(smpl->iface->apply_ggml);
|
||||
smpl->iface->apply_ggml(smpl, ctx, gf, ggml_data);
|
||||
}
|
||||
|
||||
void llama_sampler_accept_ggml(
|
||||
struct llama_sampler * smpl,
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
struct ggml_tensor * selected_token) {
|
||||
if (smpl->iface->accept_ggml) {
|
||||
smpl->iface->accept_ggml(smpl, ctx, gf, selected_token);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sampler_set_input_ggml(struct llama_sampler * smpl) {
|
||||
if (smpl->iface->set_input_ggml) {
|
||||
smpl->iface->set_input_ggml(smpl);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sampler_init_ggml(
|
||||
struct llama_sampler * smpl,
|
||||
ggml_backend_buffer_type_t buft) {
|
||||
if (smpl->iface->init_ggml) {
|
||||
smpl->iface->init_ggml(smpl, buft);
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sampler_reset(struct llama_sampler * smpl) {
|
||||
if (smpl->iface->reset) {
|
||||
smpl->iface->reset(smpl);
|
||||
|
|
@ -406,7 +439,15 @@ void llama_sampler_free(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
const llama_token sampled_token = llama_get_backend_sampled_token_ith(ctx, idx);
|
||||
const float * sampled_probs = llama_get_backend_sampled_probs_ith(ctx, idx);
|
||||
const float * sampled_logits = llama_get_backend_sampled_logits_ith(ctx, idx);
|
||||
const llama_token * sampled_ids = llama_get_backend_sampled_token_ids_ith(ctx, idx);
|
||||
|
||||
// If a backend sampler has already sampled a token, return it.
|
||||
if (sampled_token != LLAMA_TOKEN_NULL) {
|
||||
return sampled_token;
|
||||
}
|
||||
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
|
@ -415,9 +456,40 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte
|
|||
|
||||
// TODO: do not allocate each time
|
||||
std::vector<llama_token_data> cur;
|
||||
cur.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
|
||||
if (sampled_probs) {
|
||||
const uint32_t sampled_probs_count = llama_get_backend_sampled_probs_count_ith(ctx, idx);
|
||||
cur.reserve(sampled_probs_count);
|
||||
// The backend sampler has filtered the probabilities so we need to use the sampled ids.
|
||||
if (sampled_ids != nullptr) {
|
||||
for (uint32_t i = 0; i < sampled_probs_count; ++i) {
|
||||
cur.emplace_back(llama_token_data{sampled_ids[i], 0.0f, sampled_probs[i]});
|
||||
}
|
||||
} else {
|
||||
for (llama_token token_id = 0; token_id < (int) sampled_probs_count; token_id++) {
|
||||
cur.emplace_back(llama_token_data{token_id, 0.0f, sampled_probs[token_id]});
|
||||
}
|
||||
}
|
||||
} else if (sampled_logits) {
|
||||
const uint32_t sampled_logits_count = llama_get_backend_sampled_logits_count_ith(ctx, idx);
|
||||
cur.reserve(sampled_logits_count);
|
||||
// The backend sampler has filtered the logits so we need to use the sampled ids.
|
||||
if (sampled_ids != nullptr) {
|
||||
for (llama_token i = 0; i < (int)sampled_logits_count; i++) {
|
||||
cur.emplace_back(llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f});
|
||||
}
|
||||
} else {
|
||||
for (llama_token token_id = 0; token_id < (int)sampled_logits_count; token_id++) {
|
||||
cur.emplace_back(llama_token_data{token_id, sampled_logits[token_id], 0.0f});
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const auto * logits = llama_get_logits_ith(ctx, idx);
|
||||
GGML_ASSERT(logits != nullptr);
|
||||
cur.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = {
|
||||
|
|
@ -462,6 +534,10 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d
|
|||
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
// Skip GPU samplers - they have apply_ggml but no apply
|
||||
if (smpl->iface->apply == nullptr) {
|
||||
continue;
|
||||
}
|
||||
llama_sampler_apply(smpl, cur_p);
|
||||
}
|
||||
}
|
||||
|
|
@ -499,13 +575,67 @@ static void llama_sampler_chain_free(struct llama_sampler * smpl) {
|
|||
delete chain;
|
||||
}
|
||||
|
||||
static void llama_sampler_chain_apply_ggml(
|
||||
struct llama_sampler * smpl,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct llama_sampler_ggml_data * ggml_data) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
if (smpl->iface->apply_ggml) {
|
||||
smpl->iface->apply_ggml(smpl, ctx, gf, ggml_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void llama_sampler_chain_accept_ggml(
|
||||
struct llama_sampler * smpl,
|
||||
ggml_context * ctx,
|
||||
ggml_cgraph * gf,
|
||||
struct ggml_tensor * selected_token) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
if (smpl->iface->accept_ggml) {
|
||||
smpl->iface->accept_ggml(smpl, ctx, gf, selected_token);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void llama_sampler_chain_set_input_ggml(struct llama_sampler * smpl) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
if (smpl->iface->set_input_ggml) {
|
||||
smpl->iface->set_input_ggml(smpl);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void llama_sampler_chain_set_backend_context(
|
||||
struct llama_sampler * smpl,
|
||||
ggml_backend_buffer_type_t buft) {
|
||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||
|
||||
for (auto * smpl : chain->samplers) {
|
||||
if (smpl->iface->init_ggml) {
|
||||
smpl->iface->init_ggml(smpl, buft);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_chain_i = {
|
||||
/* .name = */ llama_sampler_chain_name,
|
||||
/* .accept = */ llama_sampler_chain_accept,
|
||||
/* .apply = */ llama_sampler_chain_apply,
|
||||
/* .reset = */ llama_sampler_chain_reset,
|
||||
/* .clone = */ llama_sampler_chain_clone,
|
||||
/* .free = */ llama_sampler_chain_free,
|
||||
/* .name = */ llama_sampler_chain_name,
|
||||
/* .accept = */ llama_sampler_chain_accept,
|
||||
/* .apply = */ llama_sampler_chain_apply,
|
||||
/* .reset = */ llama_sampler_chain_reset,
|
||||
/* .clone = */ llama_sampler_chain_clone,
|
||||
/* .free = */ llama_sampler_chain_free,
|
||||
/* .apply_ggml = */ llama_sampler_chain_apply_ggml,
|
||||
/* .accept_ggml = */ llama_sampler_chain_accept_ggml,
|
||||
/* .set_input_ggml = */ llama_sampler_chain_set_input_ggml,
|
||||
/* .set_backend_context = */ llama_sampler_chain_set_backend_context,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
|
||||
|
|
@ -523,6 +653,7 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
|
|||
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
|
||||
auto * p = (llama_sampler_chain *) chain->ctx;
|
||||
p->samplers.push_back(smpl);
|
||||
p->version++;
|
||||
}
|
||||
|
||||
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
|
||||
|
|
@ -544,6 +675,7 @@ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain,
|
|||
|
||||
auto * result = p->samplers[i];
|
||||
p->samplers.erase(p->samplers.begin() + i);
|
||||
p->version++;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
|
@ -554,6 +686,11 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) {
|
|||
return p->samplers.size();
|
||||
}
|
||||
|
||||
uint64_t llama_sampler_chain_get_version(const struct llama_sampler * chain) {
|
||||
const auto * p = (const llama_sampler_chain *) chain->ctx;
|
||||
return p->version;
|
||||
}
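A hedged illustration of how this version counter is meant to be consumed: mutating the chain invalidates any graph that captured the previous version. llama_sampler_chain_get_version() is internal to the library, so the snippet assumes it is visible to the caller.

struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

const uint64_t v_before = llama_sampler_chain_get_version(chain);
llama_sampler_chain_add(chain, llama_sampler_init_top_k(40)); // add/remove bump the version

// llm_graph_input_sampling::can_reuse() performs the equivalent comparison and
// forces a graph rebuild when the recorded version no longer matches
const bool reuse_ok = llama_sampler_chain_get_version(chain) == v_before; // false here

llama_sampler_free(chain);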
|
||||
|
||||
//
|
||||
// samplers
|
||||
//
|
||||
|
|
@ -574,12 +711,16 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
|
|||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_greedy_i = {
|
||||
/* .name = */ llama_sampler_greedy_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_greedy_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ nullptr,
|
||||
/* .free = */ nullptr,
|
||||
/* .name = */ llama_sampler_greedy_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_greedy_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ nullptr,
|
||||
/* .free = */ nullptr,
|
||||
/* .apply_ggml = */ nullptr,
|
||||
/* .accept_ggml = */ nullptr,
|
||||
/* .set_input_ggml = */ nullptr,
|
||||
/* .set_backend_context = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_greedy() {
|
||||
|
|
@ -699,12 +840,16 @@ static void llama_sampler_dist_free(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_dist_i = {
|
||||
/* .name = */ llama_sampler_dist_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_dist_apply,
|
||||
/* .reset = */ llama_sampler_dist_reset,
|
||||
/* .clone = */ llama_sampler_dist_clone,
|
||||
/* .free = */ llama_sampler_dist_free,
|
||||
/* .name = */ llama_sampler_dist_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_dist_apply,
|
||||
/* .reset = */ llama_sampler_dist_reset,
|
||||
/* .clone = */ llama_sampler_dist_clone,
|
||||
/* .free = */ llama_sampler_dist_free,
|
||||
/* .apply_ggml = */ nullptr,
|
||||
/* .accept_ggml = */ nullptr,
|
||||
/* .set_input_ggml = */ nullptr,
|
||||
/* .set_backend_context = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
||||
|
|
@ -744,12 +889,16 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_top_k_i = {
|
||||
/* .name = */ llama_sampler_top_k_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_top_k_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_top_k_clone,
|
||||
/* .free = */ llama_sampler_top_k_free,
|
||||
/* .name = */ llama_sampler_top_k_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_top_k_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_top_k_clone,
|
||||
/* .free = */ llama_sampler_top_k_free,
|
||||
/* .apply_ggml = */ nullptr,
|
||||
/* .accept_ggml = */ nullptr,
|
||||
/* .set_input_ggml = */ nullptr,
|
||||
/* .set_backend_context = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
||||
|
|
@ -839,12 +988,16 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_top_p_i = {
|
||||
/* .name = */ llama_sampler_top_p_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_top_p_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_top_p_clone,
|
||||
/* .free = */ llama_sampler_top_p_free,
|
||||
/* .name = */ llama_sampler_top_p_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_top_p_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_top_p_clone,
|
||||
/* .free = */ llama_sampler_top_p_free,
|
||||
/* .apply_ggml = */ nullptr,
|
||||
/* .accept_ggml = */ nullptr,
|
||||
/* .set_input_ggml = */ nullptr,
|
||||
/* .set_backend_context = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
|
||||
|
|
@ -933,12 +1086,16 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_min_p_i = {
|
||||
/* .name = */ llama_sampler_min_p_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_min_p_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_min_p_clone,
|
||||
/* .free = */ llama_sampler_min_p_free,
|
||||
/* .name = */ llama_sampler_min_p_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_min_p_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_min_p_clone,
|
||||
/* .free = */ llama_sampler_min_p_free,
|
||||
/* .apply_ggml = */ nullptr,
|
||||
/* .accept_ggml = */ nullptr,
|
||||
/* .set_input_ggml = */ nullptr,
|
||||
/* .set_backend_context = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
|
||||
|
|
@ -1032,12 +1189,16 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_typical_i = {
|
||||
/* .name = */ llama_sampler_typical_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_typical_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_typical_clone,
|
||||
/* .free = */ llama_sampler_typical_free,
|
||||
/* .name = */ llama_sampler_typical_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_typical_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_typical_clone,
|
||||
/* .free = */ llama_sampler_typical_free,
|
||||
/* .apply_ggml = */ nullptr,
|
||||
/* .accept_ggml = */ nullptr,
|
||||
/* .set_input_ggml = */ nullptr,
|
||||
/* .set_backend_context = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
|
||||
|
|
@ -1076,12 +1237,16 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_temp_i = {
|
||||
/* .name = */ llama_sampler_temp_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_temp_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_temp_clone,
|
||||
/* .free = */ llama_sampler_temp_free,
|
||||
/* .name = */ llama_sampler_temp_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_temp_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_temp_clone,
|
||||
/* .free = */ llama_sampler_temp_free,
|
||||
/* .apply_ggml = */ nullptr,
|
||||
/* .accept_ggml = */ nullptr,
|
||||
/* .set_input_ggml = */ nullptr,
|
||||
/* .set_backend_context = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_temp(float temp) {
|
||||
|
|
@ -1186,12 +1351,16 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_temp_ext_i = {
|
||||
/* .name = */ llama_sampler_temp_ext_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_temp_ext_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_temp_ext_clone,
|
||||
/* .free = */ llama_sampler_temp_ext_free,
|
||||
/* .name = */ llama_sampler_temp_ext_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_temp_ext_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_temp_ext_clone,
|
||||
/* .free = */ llama_sampler_temp_ext_free,
|
||||
/* .apply_ggml = */ nullptr,
|
||||
/* .accept_ggml = */ nullptr,
|
||||
/* .set_input_ggml = */ nullptr,
|
||||
/* .set_backend_context = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
|
||||
|
|
@ -1280,12 +1449,16 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_xtc_i = {
|
||||
/* .name = */ llama_sampler_xtc_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sample_xtc_apply,
|
||||
/* .reset = */ llama_sampler_xtc_reset,
|
||||
/* .clone = */ llama_sampler_xtc_clone,
|
||||
/* .free = */ llama_sampler_xtc_free,
|
||||
/* .name = */ llama_sampler_xtc_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sample_xtc_apply,
|
||||
/* .reset = */ llama_sampler_xtc_reset,
|
||||
/* .clone = */ llama_sampler_xtc_clone,
|
||||
/* .free = */ llama_sampler_xtc_free,
|
||||
/* .apply_ggml = */ nullptr,
|
||||
/* .accept_ggml = */ nullptr,
|
||||
/* .set_input_ggml = */ nullptr,
|
||||
/* .set_backend_context = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
|
||||
|
|
@ -1388,12 +1561,16 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_mirostat_i = {
|
||||
/* .name = */ llama_sampler_mirostat_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_mirostat_apply,
|
||||
/* .reset = */ llama_sampler_mirostat_reset,
|
||||
/* .clone = */ llama_sampler_mirostat_clone,
|
||||
/* .free = */ llama_sampler_mirostat_free,
|
||||
/* .name = */ llama_sampler_mirostat_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_mirostat_apply,
|
||||
/* .reset = */ llama_sampler_mirostat_reset,
|
||||
/* .clone = */ llama_sampler_mirostat_clone,
|
||||
/* .free = */ llama_sampler_mirostat_free,
|
||||
/* .apply_ggml = */ nullptr,
|
||||
/* .accept_ggml = */ nullptr,
|
||||
/* .set_input_ggml = */ nullptr,
|
||||
/* .set_backend_context = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
|
||||
|
|
@ -1487,12 +1664,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
|
||||
/* .name = */ llama_sampler_mirostat_v2_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_mirostat_v2_apply,
|
||||
/* .reset = */ llama_sampler_mirostat_v2_reset,
|
||||
/* .clone = */ llama_sampler_mirostat_v2_clone,
|
||||
/* .free = */ llama_sampler_mirostat_v2_free,
|
||||
/* .name = */ llama_sampler_mirostat_v2_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_mirostat_v2_apply,
|
||||
/* .reset = */ llama_sampler_mirostat_v2_reset,
|
||||
/* .clone = */ llama_sampler_mirostat_v2_clone,
|
||||
/* .free = */ llama_sampler_mirostat_v2_free,
|
||||
/* .apply_ggml = */ nullptr,
|
||||
/* .accept_ggml = */ nullptr,
|
||||
/* .set_input_ggml = */ nullptr,
|
||||
/* .set_backend_context = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
|
||||
|
|
@ -1604,12 +1785,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_grammar_i = {
|
||||
/* .name = */ llama_sampler_grammar_name,
|
||||
/* .accept = */ llama_sampler_grammar_accept_impl,
|
||||
/* .apply = */ llama_sampler_grammar_apply,
|
||||
/* .reset = */ llama_sampler_grammar_reset,
|
||||
/* .clone = */ llama_sampler_grammar_clone,
|
||||
/* .free = */ llama_sampler_grammar_free,
|
||||
/* .name = */ llama_sampler_grammar_name,
|
||||
/* .accept = */ llama_sampler_grammar_accept_impl,
|
||||
/* .apply = */ llama_sampler_grammar_apply,
|
||||
/* .reset = */ llama_sampler_grammar_reset,
|
||||
/* .clone = */ llama_sampler_grammar_clone,
|
||||
/* .free = */ llama_sampler_grammar_free,
|
||||
/* .apply_ggml = */ nullptr,
|
||||
/* .accept_ggml = */ nullptr,
|
||||
/* .set_input_ggml = */ nullptr,
|
||||
/* .set_backend_context = */ nullptr,
|
||||
};
|
||||
|
||||
static struct llama_sampler * llama_sampler_init_grammar_impl(
|
||||
|
|
@ -1811,12 +1996,16 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_penalties_i = {
|
||||
/* .name = */ llama_sampler_penalties_name,
|
||||
/* .accept = */ llama_sampler_penalties_accept,
|
||||
/* .apply = */ llama_sampler_penalties_apply,
|
||||
/* .reset = */ llama_sampler_penalties_reset,
|
||||
/* .clone = */ llama_sampler_penalties_clone,
|
||||
/* .free = */ llama_sampler_penalties_free,
|
||||
/* .name = */ llama_sampler_penalties_name,
|
||||
/* .accept = */ llama_sampler_penalties_accept,
|
||||
/* .apply = */ llama_sampler_penalties_apply,
|
||||
/* .reset = */ llama_sampler_penalties_reset,
|
||||
/* .clone = */ llama_sampler_penalties_clone,
|
||||
/* .free = */ llama_sampler_penalties_free,
|
||||
/* .apply_ggml = */ nullptr,
|
||||
/* .accept_ggml = */ nullptr,
|
||||
/* .set_input_ggml = */ nullptr,
|
||||
/* .set_backend_context = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_penalties(
|
||||
|
|
@ -1902,12 +2091,16 @@ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_top_n_sigma_i = {
|
||||
/* .name = */ llama_sampler_top_n_sigma_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_top_n_sigma_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_top_n_sigma_clone,
|
||||
/* .free = */ llama_sampler_top_n_sigma_free,
|
||||
/* .name = */ llama_sampler_top_n_sigma_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_top_n_sigma_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_top_n_sigma_clone,
|
||||
/* .free = */ llama_sampler_top_n_sigma_free,
|
||||
/* .apply_ggml = */ nullptr,
|
||||
/* .accept_ggml = */ nullptr,
|
||||
/* .set_input_ggml = */ nullptr,
|
||||
/* .set_backend_context = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
|
||||
|
|
@ -2232,12 +2425,16 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_dry_i = {
|
||||
/* .name = */ llama_sampler_dry_name,
|
||||
/* .accept = */ llama_sampler_dry_accept,
|
||||
/* .apply = */ llama_sampler_dry_apply,
|
||||
/* .reset = */ llama_sampler_dry_reset,
|
||||
/* .clone = */ llama_sampler_dry_clone,
|
||||
/* .free = */ llama_sampler_dry_free,
|
||||
/* .name = */ llama_sampler_dry_name,
|
||||
/* .accept = */ llama_sampler_dry_accept,
|
||||
/* .apply = */ llama_sampler_dry_apply,
|
||||
/* .reset = */ llama_sampler_dry_reset,
|
||||
/* .clone = */ llama_sampler_dry_clone,
|
||||
/* .free = */ llama_sampler_dry_free,
|
||||
/* .apply_ggml = */ nullptr,
|
||||
/* .accept_ggml = */ nullptr,
|
||||
/* .set_input_ggml = */ nullptr,
|
||||
/* .set_backend_context = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
|
||||
|
|
@ -2373,12 +2570,16 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_logit_bias_i = {
|
||||
/* .name = */ llama_sampler_logit_bias_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_logit_bias_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_logit_bias_clone,
|
||||
/* .free = */ llama_sampler_logit_bias_free,
|
||||
/* .name = */ llama_sampler_logit_bias_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_logit_bias_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_logit_bias_clone,
|
||||
/* .free = */ llama_sampler_logit_bias_free,
|
||||
/* .apply_ggml = */ nullptr,
|
||||
/* .accept_ggml = */ nullptr,
|
||||
/* .set_input_ggml = */ nullptr,
|
||||
/* .set_backend_context = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_logit_bias(
|
||||
|
|
@ -2603,12 +2804,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) {
|
|||
}
|
||||
|
||||
static struct llama_sampler_i llama_sampler_infill_i = {
|
||||
/* .name = */ llama_sampler_infill_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_infill_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_infill_clone,
|
||||
/* .free = */ llama_sampler_infill_free,
|
||||
/* .name = */ llama_sampler_infill_name,
|
||||
/* .accept = */ nullptr,
|
||||
/* .apply = */ llama_sampler_infill_apply,
|
||||
/* .reset = */ nullptr,
|
||||
/* .clone = */ llama_sampler_infill_clone,
|
||||
/* .free = */ llama_sampler_infill_free,
|
||||
/* .apply_ggml = */ nullptr,
|
||||
/* .accept_ggml = */ nullptr,
|
||||
/* .set_input_ggml = */ nullptr,
|
||||
/* .set_backend_context = */ nullptr,
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
|
||||
|
|
|
|||
|
|
@ -21,6 +21,9 @@ struct llama_sampler_chain {
|
|||
mutable int64_t t_sample_us;
|
||||
|
||||
mutable int32_t n_sample;
|
||||
|
||||
// simple version tracking used by the GPU sampling graph's can_reuse check
|
||||
uint64_t version = 0;
|
||||
};
|
||||
|
||||
struct llama_sampler * llama_sampler_init_dry_testing(
|
||||
|
|
|
|||
|
|
@ -206,6 +206,18 @@ llama_build_and_test(test-backend-ops.cpp)
|
|||
llama_build_and_test(test-model-load-cancel.cpp LABEL "model")
|
||||
llama_build_and_test(test-autorelease.cpp LABEL "model")
|
||||
|
||||
llama_build_and_test(test-backend-sampler.cpp LABEL "model")
|
||||
target_include_directories(test-backend-sampler PRIVATE ${PROJECT_SOURCE_DIR}/src)
|
||||
llama_test(test-backend-sampler NAME test-backend-sampler-greedy ARGS --test greedy)
|
||||
llama_test(test-backend-sampler NAME test-backend-sampler-temp ARGS --test temp)
|
||||
llama_test(test-backend-sampler NAME test-backend-sampler-top_k ARGS --test top_k)
|
||||
llama_test(test-backend-sampler NAME test-backend-sampler-dist ARGS --test dist)
|
||||
llama_test(test-backend-sampler NAME test-backend-sampler-dist-and-cpu ARGS --test dist_and_cpu)
|
||||
llama_test(test-backend-sampler NAME test-backend-sampler-logit-bias ARGS --test logit_bias)
|
||||
llama_test(test-backend-sampler NAME test-backend-sampler-mul_seq ARGS --test multi_sequence)
|
||||
llama_test(test-backend-sampler NAME test-backend-sampler-set-sampler ARGS --test set_sampler)
|
||||
|
||||
|
||||
if (NOT GGML_BACKEND_DL)
|
||||
# these tests use the backends directly and cannot be built with dynamic loading
|
||||
llama_build_and_test(test-barrier.cpp)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,760 @@
|
|||
#include "ggml.h"
|
||||
#include "llama.h"
|
||||
#include "get-model.h"
|
||||
#include "common.h"
|
||||
|
||||
#ifdef NDEBUG
|
||||
#undef NDEBUG
|
||||
#endif
|
||||
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <array>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
struct test_model_context {
|
||||
llama_model * model = nullptr;
|
||||
llama_context * ctx = nullptr;
|
||||
const llama_vocab * vocab = nullptr;
|
||||
int n_vocab = 0;
|
||||
std::unordered_map<llama_seq_id, int32_t> seq_positions;
|
||||
std::unordered_map<llama_seq_id, int32_t> last_batch_info;
|
||||
|
||||
bool setup_model(const char * model_path) {
|
||||
if (model != nullptr) {
|
||||
return true;
|
||||
}
|
||||
|
||||
llama_backend_init();
|
||||
|
||||
llama_model_params mparams = llama_model_default_params();
|
||||
model = llama_model_load_from_file(model_path, mparams);
|
||||
if (model == nullptr) {
|
||||
fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", model_path);
|
||||
cleanup();
|
||||
return false;
|
||||
}
|
||||
vocab = llama_model_get_vocab(model);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool setup(const char * model_path, std::vector<llama_sampler_seq_config> & configs) {
|
||||
if (model == nullptr) {
|
||||
setup_model(model_path);
|
||||
}
|
||||
|
||||
if (model != nullptr && ctx != nullptr) {
|
||||
return true;
|
||||
}
|
||||
|
||||
llama_context_params cparams = llama_context_default_params();
|
||||
cparams.n_ctx = 512;
|
||||
cparams.n_batch = 512;
|
||||
cparams.samplers = configs.data();
|
||||
cparams.n_samplers = configs.size();
|
||||
|
||||
int32_t max_seq_id = 0;
|
||||
for (const auto & config : configs) {
|
||||
if (config.seq_id > max_seq_id) {
|
||||
max_seq_id = config.seq_id;
|
||||
}
|
||||
}
|
||||
cparams.n_seq_max = max_seq_id + 1;
|
||||
|
||||
ctx = llama_init_from_model(model, cparams);
|
||||
if (ctx == nullptr) {
|
||||
fprintf(stderr, "Warning: failed to create context, skipping test\n");
|
||||
cleanup();
|
||||
return false;
|
||||
}
|
||||
llama_set_warmup(ctx, false);
|
||||
|
||||
vocab = llama_model_get_vocab(model);
|
||||
n_vocab = llama_vocab_n_tokens(vocab);
|
||||
fprintf(stderr, "Vocabulary size: %d\n", n_vocab);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool decode(const std::map<llama_seq_id, std::string> & prompts) {
|
||||
if (ctx == nullptr || vocab == nullptr) {
|
||||
fprintf(stderr, "Error: context not initialized, call setup() first\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
last_batch_info.clear();
|
||||
llama_batch batch = llama_batch_init(512, 0, prompts.size());
|
||||
|
||||
int n_tokens_per_prompt = 0;
|
||||
|
||||
for (const auto & [seq_id, prompt] : prompts) {
|
||||
std::vector<llama_token> tokens;
|
||||
tokens.push_back(llama_vocab_bos(vocab));
|
||||
|
||||
std::vector<llama_token> prompt_tokens(32);
|
||||
int n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.length(),
|
||||
prompt_tokens.data(), prompt_tokens.size(),
|
||||
false, false);
|
||||
//TODO: refactor this function to just handle a single prompt at a time
|
||||
// to avoid this check and complexity.
|
||||
if (n_tokens_per_prompt == 0) {
|
||||
n_tokens_per_prompt = n_tokens;
|
||||
} else {
|
||||
if (n_tokens != n_tokens_per_prompt) {
|
||||
fprintf(stderr, "Error: prompts must have the same number of tokens\n");
|
||||
llama_batch_free(batch);
|
||||
return false;
|
||||
}
|
||||
n_tokens_per_prompt = n_tokens;
|
||||
}
|
||||
if (n_tokens < 0) {
|
||||
fprintf(stderr, "Warning: tokenization failed for seq_id %d\n", seq_id);
|
||||
llama_batch_free(batch);
|
||||
return false;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_tokens; i++) {
|
||||
tokens.push_back(prompt_tokens[i]);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
common_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
|
||||
}
|
||||
|
||||
seq_positions[seq_id] = tokens.size();
|
||||
}
|
||||
|
||||
|
||||
printf("Batch contents:\n");
|
||||
printf(" n_tokens: %d\n", batch.n_tokens);
|
||||
for (int i = 0; i < batch.n_tokens; i++) {
|
||||
printf(" token[%d]: tok=%-5d, pos=%d, n_seq_id=%d, seq_ids=[", i, batch.token[i], batch.pos[i], batch.n_seq_id[i]);
|
||||
|
||||
for (int j = 0; j < batch.n_seq_id[i]; j++) {
|
||||
printf("%d%s", batch.seq_id[i][j], j < batch.n_seq_id[i]-1 ? ", " : "");
|
||||
}
|
||||
printf("], logits=%d\n", batch.logits[i]);
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
fprintf(stderr, "Warning: llama_decode failed\n");
|
||||
llama_batch_free(batch);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Build mapping from seq id to batch token idx
|
||||
for (int i = 0; i < batch.n_tokens; i++) {
|
||||
if (batch.logits[i]) {
|
||||
llama_seq_id seq_id = batch.seq_id[i][0];
|
||||
last_batch_info[seq_id] = i;
|
||||
printf("seq %d : batch idx %d\n", seq_id, i);
|
||||
}
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
return true;
|
||||
}
|
||||
|
||||
int32_t idx_for_seq(llama_seq_id seq_id) {
|
||||
auto it = last_batch_info.find(seq_id);
|
||||
if (it == last_batch_info.end()) {
|
||||
fprintf(stderr, "Error: no batch index found for seq_id %d\n", seq_id);
|
||||
return -1;
|
||||
}
|
||||
return it->second;
|
||||
}
|
||||
|
||||
bool decode_token(llama_token token, llama_seq_id seq_id = 0) {
|
||||
if (ctx == nullptr) {
|
||||
fprintf(stderr, "Error: context not initialized, call setup() first\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
llama_batch batch = llama_batch_init(1, 0, 1);
|
||||
int32_t pos = seq_positions[seq_id];
|
||||
common_batch_add(batch, token, pos, { seq_id }, true);
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
fprintf(stderr, "Warning: llama_decode failed for token %d in seq %d\n", token, seq_id);
|
||||
llama_batch_free(batch);
|
||||
return false;
|
||||
}
|
||||
|
||||
last_batch_info.clear();
|
||||
for (int i = 0; i < batch.n_tokens; i++) {
|
||||
if (batch.logits[i]) {
|
||||
llama_seq_id cur_seq = batch.seq_id[i][0];
|
||||
last_batch_info[cur_seq] = i;
|
||||
}
|
||||
}
|
||||
|
||||
seq_positions[seq_id]++;
|
||||
llama_batch_free(batch);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool decode_tokens(const std::map<llama_seq_id, llama_token> & seq_tokens) {
|
||||
if (ctx == nullptr) {
|
||||
fprintf(stderr, "Error: context not initialized, call setup() first\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
llama_batch batch = llama_batch_init(seq_tokens.size(), 0, seq_tokens.size());
|
||||
|
||||
for (const auto & [seq_id, token] : seq_tokens) {
|
||||
int32_t pos = seq_positions[seq_id];
|
||||
common_batch_add(batch, token, pos, { seq_id }, true);
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
fprintf(stderr, "Warning: llama_decode failed for batch tokens\n");
|
||||
llama_batch_free(batch);
|
||||
return false;
|
||||
}
|
||||
|
||||
for (const auto & [seq_id, _] : seq_tokens) {
|
||||
seq_positions[seq_id]++;
|
||||
}
|
||||
|
||||
last_batch_info.clear();
|
||||
for (int i = 0; i < batch.n_tokens; i++) {
|
||||
if (batch.logits[i]) {
|
||||
llama_seq_id cur_seq = batch.seq_id[i][0];
|
||||
last_batch_info[cur_seq] = i;
|
||||
}
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string token_to_piece(llama_token token, bool special) {
|
||||
std::string piece;
|
||||
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
|
||||
const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
|
||||
if (n_chars < 0) {
|
||||
piece.resize(-n_chars);
|
||||
int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
|
||||
GGML_ASSERT(check == -n_chars);
|
||||
}
|
||||
else {
|
||||
piece.resize(n_chars);
|
||||
}
|
||||
|
||||
return piece;
|
||||
}
|
||||
|
||||
void cleanup() {
|
||||
if (ctx) llama_free(ctx);
|
||||
if (model) llama_model_free(model);
|
||||
llama_backend_free();
|
||||
ctx = nullptr;
|
||||
model = nullptr;
|
||||
vocab = nullptr;
|
||||
}
|
||||
|
||||
~test_model_context() {
|
||||
cleanup();
|
||||
}
|
||||
};
|
||||
|
||||
static void test_backend_greedy_sampling(const char * model_path) {
|
||||
test_model_context test_ctx;
|
||||
|
||||
const int seq_id = 0;
|
||||
|
||||
struct llama_sampler_chain_params backend_sampler_params = llama_sampler_chain_default_params();
|
||||
struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_sampler_params);
|
||||
|
||||
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_greedy());
|
||||
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
|
||||
|
||||
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!test_ctx.decode({{seq_id, "Some"}})) {
|
||||
return;
|
||||
}
|
||||
|
||||
int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
|
||||
|
||||
llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
|
||||
printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
|
||||
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
|
||||
|
||||
token = llama_get_backend_sampled_token_ith(test_ctx.ctx, -1);
|
||||
printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
|
||||
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
|
||||
|
||||
for (int i = 0; i < 10; i++) {
|
||||
int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
|
||||
llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, loop_idx);
|
||||
printf("Generation step %d: token id:%d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
|
||||
test_ctx.decode_token(token, 0);
|
||||
}
|
||||
}
|
||||
|
||||
static void test_backend_top_k_sampling(const char * model_path) {
|
||||
test_model_context test_ctx;
|
||||
|
||||
const int seq_id = 0;
|
||||
const int32_t k = 8;
|
||||
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
|
||||
struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
|
||||
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_top_k(k));
|
||||
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};
|
||||
|
||||
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!test_ctx.decode({{seq_id, "Hello"}})) {
|
||||
return;
|
||||
}
|
||||
|
||||
int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
|
||||
|
||||
float * logits = llama_get_backend_sampled_logits_ith(test_ctx.ctx, batch_idx);
|
||||
uint32_t n_logits = llama_get_backend_sampled_logits_count_ith(test_ctx.ctx, batch_idx);
|
||||
for (size_t i = 0; i < n_logits; ++i) {
|
||||
printf("top_k logit[%zu] = %.6f\n", i, logits[i]);
|
||||
}
|
||||
|
||||
// Sample using a CPU sampler to verify that hybrid sampling works:
|
||||
// first top_k on the backend and then dist on the CPU.
|
||||
struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
|
||||
struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
|
||||
GGML_ASSERT(chain->iface->apply_ggml != nullptr);
|
||||
|
||||
llama_sampler_chain_add(chain, llama_sampler_init_dist(18));
|
||||
llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
|
||||
const std::string token_str = test_ctx.token_to_piece(token, false);
|
||||
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
|
||||
|
||||
printf("backend top-k hybrid sampling test PASSED\n");
|
||||
|
||||
llama_sampler_free(chain);
|
||||
}
|
||||
|
||||
static void test_backend_temp_sampling(const char * model_path) {
|
||||
test_model_context test_ctx;
|
||||
|
||||
const float temp_0 = 0.8f;
|
||||
struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params();
|
||||
struct llama_sampler * backend_sampler_chain_0 = llama_sampler_chain_init(backend_chain_params_0);
|
||||
llama_sampler_chain_add(backend_sampler_chain_0, llama_sampler_backend_init_temp(temp_0));
|
||||
|
||||
const float temp_1 = 0.1f;
|
||||
struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params();
|
||||
struct llama_sampler * backend_sampler_chain_1 = llama_sampler_chain_init(backend_chain_params_1);
|
||||
llama_sampler_chain_add(backend_sampler_chain_1, llama_sampler_backend_init_temp(temp_1));
|
||||
|
||||
std::vector<llama_sampler_seq_config> backend_sampler_configs = {
|
||||
{ 0, backend_sampler_chain_0 },
|
||||
{ 1, backend_sampler_chain_1 }
|
||||
};
|
||||
|
||||
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!test_ctx.decode({{0, "Some where over"}, {1, "Once upon a"}})) {
|
||||
return;
|
||||
}
|
||||
|
||||
int32_t batch_idx_0 = test_ctx.idx_for_seq(0);
|
||||
int32_t batch_idx_1 = test_ctx.idx_for_seq(1);
|
||||
|
||||
// Sample from sequence 0 using CPU sampler
|
||||
struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
|
||||
struct llama_sampler * chain_0 = llama_sampler_chain_init(chain_params_0);
|
||||
llama_sampler_chain_add(chain_0, llama_sampler_init_dist(18));
|
||||
|
||||
llama_token token_0 = llama_sampler_sample(chain_0, test_ctx.ctx, batch_idx_0);
|
||||
const std::string token_0_str = test_ctx.token_to_piece(token_0, false);
|
||||
printf("Sequence 0 sampled token id:%d, string: '%s'\n", token_0, token_0_str.c_str());
|
||||
GGML_ASSERT(token_0 >= 0 && token_0 < test_ctx.n_vocab);
|
||||
|
||||
// Sample from sequence 1 using CPU sampler
|
||||
struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
|
||||
struct llama_sampler * chain_1 = llama_sampler_chain_init(chain_params_1);
|
||||
llama_sampler_chain_add(chain_1, llama_sampler_init_dist(18));
|
||||
|
||||
llama_token token_1 = llama_sampler_sample(chain_1, test_ctx.ctx, batch_idx_1);
|
||||
const std::string token_1_str = test_ctx.token_to_piece(token_1, false);
|
||||
printf("Sequence 1 sampled token id:%d, string: '%s'\n", token_1, token_1_str.c_str());
|
||||
GGML_ASSERT(token_1 >= 0 && token_1 < test_ctx.n_vocab);
|
||||
|
||||
printf("backend temp sampling test PASSED\n");
|
||||
|
||||
llama_sampler_free(chain_0);
|
||||
llama_sampler_free(chain_1);
|
||||
}
|
||||
|
||||
static void test_backend_multi_sequence_sampling(const char * model_path) {
|
||||
test_model_context test_ctx;
|
||||
|
||||
struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
|
||||
struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
|
||||
llama_sampler_chain_add(sampler_chain_0, llama_sampler_backend_init_greedy());
|
||||
|
||||
struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params();
|
||||
struct llama_sampler * sampler_chain_1 = llama_sampler_chain_init(chain_params_1);
|
||||
llama_sampler_chain_add(sampler_chain_1, llama_sampler_backend_init_temp(0.8f));
|
||||
llama_sampler_chain_add(sampler_chain_1, llama_sampler_backend_init_greedy());
|
||||
|
||||
std::vector<llama_sampler_seq_config> backend_sampler_configs = {
|
||||
{ 0, sampler_chain_0 },
|
||||
{ 1, sampler_chain_1 }
|
||||
};
|
||||
|
||||
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::map<llama_seq_id, std::string> prompts = {
|
||||
{0, "Hello"},
|
||||
{1, "Some"}
|
||||
};
|
||||
|
||||
if (!test_ctx.decode(prompts)) {
|
||||
return;
|
||||
}
|
||||
|
||||
int32_t batch_idx_0 = test_ctx.idx_for_seq(0);
|
||||
llama_token seq0_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx_0);
|
||||
const std::string seq0_token_str = test_ctx.token_to_piece(seq0_token, false);
|
||||
printf("Seq 0 sampled token id=%d, string='%s'\n", seq0_token, seq0_token_str.c_str());
|
||||
GGML_ASSERT(seq0_token >= 0 && seq0_token < test_ctx.n_vocab);
|
||||
|
||||
int32_t batch_idx_1 = test_ctx.idx_for_seq(1);
|
||||
llama_token seq1_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx_1);
|
||||
const std::string seq1_token_str = test_ctx.token_to_piece(seq1_token, false);
|
||||
printf("Seq 1 sampled token id=%d, string='%s'\n", seq1_token, seq1_token_str.c_str());
|
||||
GGML_ASSERT(seq1_token >= 0 && seq1_token < test_ctx.n_vocab);
|
||||
|
||||
// Generate tokens for each sequence
|
||||
printf("\nMulti-sequence generation:\n");
|
||||
for (int step = 0; step < 4; step++) {
|
||||
std::map<llama_seq_id, llama_token> tokens;
|
||||
|
||||
for (llama_seq_id seq_id : {0, 1}) {
|
||||
int32_t idx = test_ctx.idx_for_seq(seq_id);
|
||||
llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, idx);
|
||||
const std::string token_str = test_ctx.token_to_piece(token, false);
|
||||
printf(" Seq %d, step %d: token id=%d, string='%s'\n", seq_id, step, token, token_str.c_str());
|
||||
tokens[seq_id] = token;
|
||||
}
|
||||
|
||||
// Decode all tokens in a single batch
|
||||
if (!test_ctx.decode_tokens(tokens)) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
printf("backend multi-sequence sampling test PASSED\n");
|
||||
}
|
||||
|
||||
static void test_backend_dist_sampling(const char * model_path) {
|
||||
test_model_context test_ctx;
|
||||
|
||||
const int32_t seed = 88;
|
||||
struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
|
||||
struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
|
||||
llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed));
|
||||
std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ 0, backend_sampler_chain }};
|
||||
|
||||
if (!test_ctx.setup(model_path, backend_sampler_configs)) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!test_ctx.decode({{0, "Hello"}})) {
|
||||
return;
|
||||
}
|
||||
|
||||
llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(0));
|
||||
printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
|
||||
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
|
||||
|
||||
token = llama_get_backend_sampled_token_ith(test_ctx.ctx, -1);
|
||||
printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
|
||||
GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
|
||||
}
|
||||
|
||||
static void test_backend_dist_sampling_and_cpu(const char * model_path) {
    test_model_context test_ctx;

    const int seq_id = 0;
    const int32_t seed = 88;
    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed));
    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};

    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
        return;
    }

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        return;
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    // Sample using CPU sampler
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
    llama_sampler_chain_add(chain, llama_sampler_init_dist(18));

    llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
    llama_token cpu_token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
    GGML_ASSERT(backend_token == cpu_token);
}

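// Apply a large positive logit bias to a single token on the backend and verify
// that the backend dist sampler then selects exactly that token.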
static void test_backend_logit_bias_sampling(const char * model_path) {
    test_model_context test_ctx;

    // Calling setup_model to ensure vocab is loaded and can be accessed
    if (!test_ctx.setup_model(model_path)) {
        return;
    }

    const int seq_id = 0;

    // Create the logit biases vector.
    std::vector<llama_logit_bias> logit_bias;

    // Get the token for the piece "World".
    const std::string piece = "World";
    std::vector<llama_token> tokens(16);
    llama_tokenize(test_ctx.vocab, piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false);
    llama_token bias_token = tokens[0];
    logit_bias.push_back({ bias_token, +100.0f });
    printf("biasing token piece '%s' -> token id %d\n", piece.c_str(), bias_token);

    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_logit_bias(
        llama_vocab_n_tokens(test_ctx.vocab),
        logit_bias.size(),
        logit_bias.data()));
    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(88));

    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
        { seq_id, backend_sampler_chain },
    };

    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
        return;
    }

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        return;
    }

    llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id));
    const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false);
    printf("logit bias sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());
    GGML_ASSERT(backend_token == bias_token);
}

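// Exercise llama_set_backend_sampler(): clear the backend sampler for a sequence
// (after which no backend-sampled token or probs should be available), fall back
// to CPU sampling, then attach a new backend sampler chain and decode again.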
static void test_backend_set_sampler(const char * model_path) {
    test_model_context test_ctx;

    const int32_t seed = 88;
    const int seq_id = 0;
    struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * backend_sampler_chain = llama_sampler_chain_init(backend_chain_params);
    llama_sampler_chain_add(backend_sampler_chain, llama_sampler_backend_init_dist(seed));
    std::vector<llama_sampler_seq_config> backend_sampler_configs = {{ seq_id, backend_sampler_chain }};

    if (!test_ctx.setup(model_path, backend_sampler_configs)) {
        return;
    }

    if (!test_ctx.decode({{seq_id, "Hello"}})) {
        return;
    }

    int32_t batch_idx = test_ctx.idx_for_seq(seq_id);

    // Sample using backend sampler configured above
    llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
    const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false);
    printf("dist sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str());

    // Now clear the backend sampler for this sequence.
    llama_set_backend_sampler(test_ctx.ctx, seq_id, nullptr);
    printf("Cleared backend sampler for seq_id %d\n", seq_id);

    // Sample using CPU sampler
    struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
    llama_sampler_chain_add(chain, llama_sampler_init_dist(18));

    std::map<llama_seq_id, llama_token> tokens = { { seq_id, backend_token}, };
    if (!test_ctx.decode_tokens(tokens)) {
        return;
    }

    // Should not have any sampled token or probs after clearing the backend sampler.
    const int32_t idx = test_ctx.idx_for_seq(seq_id);
    GGML_ASSERT(llama_get_backend_sampled_token_ith(test_ctx.ctx, idx) == LLAMA_TOKEN_NULL);
    GGML_ASSERT(llama_get_backend_sampled_probs_ith(test_ctx.ctx, idx) == nullptr);

    // Sample the token using the CPU sampler chain.
    llama_token token2 = llama_sampler_sample(chain, test_ctx.ctx, seq_id);
    const std::string token2_str = test_ctx.token_to_piece(token2, false);
    printf("CPU sampled token after clearing backend sampler: id=%d, string='%s'\n", token2, token2_str.c_str());
    std::map<llama_seq_id, llama_token> tokens2 = { { seq_id, token2}, };

    // Set a new backend sampler for the sequence.
    struct llama_sampler_chain_params new_backend_chain_params = llama_sampler_chain_default_params();
    struct llama_sampler * new_backend_sampler_chain = llama_sampler_chain_init(new_backend_chain_params);
    llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_backend_init_top_k(20));
    llama_sampler_chain_add(new_backend_sampler_chain, llama_sampler_backend_init_dist(seed));
    llama_set_backend_sampler(test_ctx.ctx, seq_id, new_backend_sampler_chain);

    if (!test_ctx.decode_tokens(tokens2)) {
        return;
    }

    llama_token new_backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id));
    const std::string new_backend_token_str = test_ctx.token_to_piece(new_backend_token, false);
    printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str());
}

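// Registry of backend sampling tests; entries with enabled_by_default == false
// are only run when selected explicitly via --test=<name>.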
struct backend_test_case {
    const char * name;
    void (*fn)(const char *);
    bool enabled_by_default;
};

static const backend_test_case BACKEND_TESTS[] = {
    { "greedy",         test_backend_greedy_sampling,         true  },
    { "logit_bias",     test_backend_logit_bias_sampling,     true  },
    { "temp",           test_backend_temp_sampling,           true  },
    { "top_k",          test_backend_top_k_sampling,          false },
    { "multi_sequence", test_backend_multi_sequence_sampling, false },
    { "dist",           test_backend_dist_sampling,           false },
    { "dist_and_cpu",   test_backend_dist_sampling_and_cpu,   false },
    { "set_sampler",    test_backend_set_sampler,             true  },
};

struct backend_cli_args {
    const char * model = nullptr;
    const char * test  = nullptr;
};

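// Parse the test binary's command line. The model can be given either as a bare
// positional argument or via --model/--model=<path>, and a single test can be
// selected via --test/--test=<name>. For example (binary name and model path
// below are placeholders, not part of this change):
//
//   test-backend-sampling --model=models/tiny-model.gguf --test=dist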
static backend_cli_args parse_backend_cli(int argc, char ** argv) {
    backend_cli_args out;

    for (int i = 1; i < argc; ++i) {
        const char * arg = argv[i];

        if (std::strcmp(arg, "--test") == 0) {
            if (i + 1 >= argc) {
                fprintf(stderr, "--test expects a value\n");
                exit(EXIT_FAILURE);
            }
            out.test = argv[++i];
            continue;
        }
        if (std::strncmp(arg, "--test=", 7) == 0) {
            out.test = arg + 7;
            continue;
        }
        if (std::strcmp(arg, "--model") == 0) {
            if (i + 1 >= argc) {
                fprintf(stderr, "--model expects a value\n");
                exit(EXIT_FAILURE);
            }
            out.model = argv[++i];
            continue;
        }
        if (std::strncmp(arg, "--model=", 8) == 0) {
            out.model = arg + 8;
            continue;
        }
        if (!out.model) {
            out.model = arg;
            continue;
        }

        fprintf(stderr, "Unexpected argument: %s\n", arg);
        exit(EXIT_FAILURE);
    }

    return out;
}

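// Resolve which tests to run: the single requested test when --test=<name> is
// given (exiting with the list of available tests if the name is unknown), or
// all tests that are enabled by default otherwise.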
static std::vector<const backend_test_case *> collect_tests_to_run(const char * requested) {
    std::vector<const backend_test_case *> selected;

    if (requested != nullptr) {
        for (const auto & test : BACKEND_TESTS) {
            if (std::strcmp(test.name, requested) == 0) {
                selected.push_back(&test);
                break;
            }
        }
        if (selected.empty()) {
            fprintf(stderr, "Unknown test '%s'. Available tests:\n", requested);
            for (const auto & test : BACKEND_TESTS) {
                fprintf(stderr, " %s\n", test.name);
            }
            exit(EXIT_FAILURE);
        }
    } else {
        for (const auto & test : BACKEND_TESTS) {
            if (test.enabled_by_default) {
                selected.push_back(&test);
            }
        }
    }

    if (selected.empty()) {
        fprintf(stderr, "No backend sampling tests selected. Use --test=<name> to pick one.\n");
    }

    return selected;
}

static void run_tests(const std::vector<const backend_test_case *> & tests, const char * model_path) {
    for (const auto * test : tests) {
        fprintf(stderr, "\n=== %s ===\n", test->name);
        test->fn(model_path);
    }
}

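// Entry point: parse the CLI arguments, verify that the model file exists, and
// run the selected backend sampling tests.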
int main(int argc, char * argv[]) {
    const backend_cli_args args = parse_backend_cli(argc, argv);

    std::array<char *, 2> model_argv { argv[0], const_cast<char *>(args.model) };
    const int model_argc = args.model ? 2 : 1;
    char * model_path = get_model_or_exit(model_argc, model_argv.data());

    auto * file = fopen(model_path, "r");
    if (file == nullptr) {
        fprintf(stderr, "no model at '%s' found\n", model_path);
        return EXIT_FAILURE;
    }

    fprintf(stderr, "using '%s'\n", model_path);
    fclose(file);

    ggml_time_init();

    const std::vector<const backend_test_case *> tests = collect_tests_to_run(args.test);
    if (!tests.empty()) {
        run_tests(tests, model_path);
    }

    return 0;
}